diff --git a/onnxruntime/test/providers/cpu/controlflow/if_test.cc b/onnxruntime/test/providers/cpu/controlflow/if_test.cc index 2dd103ede6..83ced88d6a 100644 --- a/onnxruntime/test/providers/cpu/controlflow/if_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/if_test.cc @@ -339,5 +339,79 @@ TEST(If, Opset11ThenAndElseBranchesProduceDifferentOutputShapes) { RunTest(false, options, false, OpTester::ExpectResult::kExpectSuccess, "", 11); } +// This is to test an "If" node with just "Constant" nodes in the "then" and "else" conditional branches +class IfOpTesterOnlyConstantNodesInConditionalBranches : public OpTester { + public: + IfOpTesterOnlyConstantNodesInConditionalBranches() : OpTester("If") { + } + + protected: + void AddNodes(onnxruntime::Graph& graph, + std::vector& graph_input_defs, + std::vector& graph_output_defs, + std::vector>& /*add_attribute_funcs*/) override { + // Graph inputs are 0:Cond for If + ASSERT_EQ(graph_input_defs.size(), 1u); + ASSERT_EQ(graph_output_defs.size(), 1u); + + NodeArg* if_cond_input = graph_input_defs[0]; + + std::vector inputs; + std::vector outputs; + + // add If node + { + inputs = {if_cond_input}; + outputs = {graph_output_defs[0]}; + + auto& if_node = graph.AddNode("if", "If", "If node", inputs, outputs); + + auto CreateSubgraphWithConstantNode = [](bool then_branch, float value, std::vector outputs) { + Model model_then(then_branch ? "Then" : "Else", false, DefaultLoggingManager().DefaultLogger()); + auto& graph_then = model_then.MainGraph(); + auto& then_constant_node = graph_then.AddNode( + then_branch ? "Constant_Then" : "Constant_Else", + "Constant", + then_branch ? "Constant_Then" : "Constant_Else", {}, outputs); + + AttributeProto then_constant_attr_proto; + then_constant_attr_proto.set_name("value"); + then_constant_attr_proto.set_type(AttributeProto_AttributeType_TENSOR); + auto* then_constant_attr_tensor_proto = then_constant_attr_proto.mutable_t(); + then_constant_attr_tensor_proto->set_data_type(TensorProto_DataType_FLOAT); + then_constant_attr_tensor_proto->add_dims(1); + then_constant_attr_tensor_proto->add_float_data(value); // Constant value of 10.f + + then_constant_node.AddAttribute("value", then_constant_attr_proto); + + auto status_then = graph_then.Resolve(); + EXPECT_EQ(status_then, Status::OK()); + + auto& graphproto_then = graph_then.ToGraphProto(); + return graphproto_then; + }; + + if_node.AddAttribute("then_branch", CreateSubgraphWithConstantNode(true, 10.f, outputs)); + if_node.AddAttribute("else_branch", CreateSubgraphWithConstantNode(false, 1000.f, outputs)); + } + } +}; + +// Context: Github issue #3900 +TEST(If, ConditionalBranchesOnlyContainConstantNodes_ThenBranchExecution) { + IfOpTesterOnlyConstantNodesInConditionalBranches test; + test.AddInput("If_input", {1}, {true}); + test.AddOutput("If_output", {1}, {10.f}); + test.Run(); +} + +// Context: Github issue #3900 +TEST(If, ConditionalBranchesOnlyContainConstantNodes_ElseBranchExecution) { + IfOpTesterOnlyConstantNodesInConditionalBranches test; + test.AddInput("If_input", {1}, {false}); + test.AddOutput("If_output", {1}, {1000.f}); + test.Run(); +} + } // namespace test } // namespace onnxruntime