mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add test for If node with conditional branches only containing Constant nodes (#3949)
This commit is contained in:
parent
38467f8c9a
commit
9ef376880b
1 changed files with 74 additions and 0 deletions
|
|
@ -339,5 +339,79 @@ TEST(If, Opset11ThenAndElseBranchesProduceDifferentOutputShapes) {
|
|||
RunTest(false, options, false, OpTester::ExpectResult::kExpectSuccess, "", 11);
|
||||
}
|
||||
|
||||
// This is to test an "If" node with just "Constant" nodes in the "then" and "else" conditional branches
|
||||
class IfOpTesterOnlyConstantNodesInConditionalBranches : public OpTester {
|
||||
public:
|
||||
IfOpTesterOnlyConstantNodesInConditionalBranches() : OpTester("If") {
|
||||
}
|
||||
|
||||
protected:
|
||||
void AddNodes(onnxruntime::Graph& graph,
|
||||
std::vector<onnxruntime::NodeArg*>& graph_input_defs,
|
||||
std::vector<onnxruntime::NodeArg*>& graph_output_defs,
|
||||
std::vector<std::function<void(onnxruntime::Node& node)>>& /*add_attribute_funcs*/) override {
|
||||
// Graph inputs are 0:Cond for If
|
||||
ASSERT_EQ(graph_input_defs.size(), 1u);
|
||||
ASSERT_EQ(graph_output_defs.size(), 1u);
|
||||
|
||||
NodeArg* if_cond_input = graph_input_defs[0];
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
|
||||
// add If node
|
||||
{
|
||||
inputs = {if_cond_input};
|
||||
outputs = {graph_output_defs[0]};
|
||||
|
||||
auto& if_node = graph.AddNode("if", "If", "If node", inputs, outputs);
|
||||
|
||||
auto CreateSubgraphWithConstantNode = [](bool then_branch, float value, std::vector<NodeArg*> outputs) {
|
||||
Model model_then(then_branch ? "Then" : "Else", false, DefaultLoggingManager().DefaultLogger());
|
||||
auto& graph_then = model_then.MainGraph();
|
||||
auto& then_constant_node = graph_then.AddNode(
|
||||
then_branch ? "Constant_Then" : "Constant_Else",
|
||||
"Constant",
|
||||
then_branch ? "Constant_Then" : "Constant_Else", {}, outputs);
|
||||
|
||||
AttributeProto then_constant_attr_proto;
|
||||
then_constant_attr_proto.set_name("value");
|
||||
then_constant_attr_proto.set_type(AttributeProto_AttributeType_TENSOR);
|
||||
auto* then_constant_attr_tensor_proto = then_constant_attr_proto.mutable_t();
|
||||
then_constant_attr_tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
|
||||
then_constant_attr_tensor_proto->add_dims(1);
|
||||
then_constant_attr_tensor_proto->add_float_data(value); // Constant value of 10.f
|
||||
|
||||
then_constant_node.AddAttribute("value", then_constant_attr_proto);
|
||||
|
||||
auto status_then = graph_then.Resolve();
|
||||
EXPECT_EQ(status_then, Status::OK());
|
||||
|
||||
auto& graphproto_then = graph_then.ToGraphProto();
|
||||
return graphproto_then;
|
||||
};
|
||||
|
||||
if_node.AddAttribute("then_branch", CreateSubgraphWithConstantNode(true, 10.f, outputs));
|
||||
if_node.AddAttribute("else_branch", CreateSubgraphWithConstantNode(false, 1000.f, outputs));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Context: Github issue #3900
|
||||
TEST(If, ConditionalBranchesOnlyContainConstantNodes_ThenBranchExecution) {
|
||||
IfOpTesterOnlyConstantNodesInConditionalBranches test;
|
||||
test.AddInput<bool>("If_input", {1}, {true});
|
||||
test.AddOutput<float>("If_output", {1}, {10.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
// Context: Github issue #3900
|
||||
TEST(If, ConditionalBranchesOnlyContainConstantNodes_ElseBranchExecution) {
|
||||
IfOpTesterOnlyConstantNodesInConditionalBranches test;
|
||||
test.AddInput<bool>("If_input", {1}, {false});
|
||||
test.AddOutput<float>("If_output", {1}, {1000.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue