Add test for If node with conditional branches only containing Constant nodes (#3949)

This commit is contained in:
Hariharan Seshadri 2020-05-14 19:21:40 -07:00 committed by GitHub
parent 38467f8c9a
commit 9ef376880b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -339,5 +339,79 @@ TEST(If, Opset11ThenAndElseBranchesProduceDifferentOutputShapes) {
RunTest(false, options, false, OpTester::ExpectResult::kExpectSuccess, "", 11);
}
// This is to test an "If" node with just "Constant" nodes in the "then" and "else" conditional branches
class IfOpTesterOnlyConstantNodesInConditionalBranches : public OpTester {
public:
IfOpTesterOnlyConstantNodesInConditionalBranches() : OpTester("If") {
}
protected:
void AddNodes(onnxruntime::Graph& graph,
std::vector<onnxruntime::NodeArg*>& graph_input_defs,
std::vector<onnxruntime::NodeArg*>& graph_output_defs,
std::vector<std::function<void(onnxruntime::Node& node)>>& /*add_attribute_funcs*/) override {
// Graph inputs are 0:Cond for If
ASSERT_EQ(graph_input_defs.size(), 1u);
ASSERT_EQ(graph_output_defs.size(), 1u);
NodeArg* if_cond_input = graph_input_defs[0];
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;
// add If node
{
inputs = {if_cond_input};
outputs = {graph_output_defs[0]};
auto& if_node = graph.AddNode("if", "If", "If node", inputs, outputs);
auto CreateSubgraphWithConstantNode = [](bool then_branch, float value, std::vector<NodeArg*> outputs) {
Model model_then(then_branch ? "Then" : "Else", false, DefaultLoggingManager().DefaultLogger());
auto& graph_then = model_then.MainGraph();
auto& then_constant_node = graph_then.AddNode(
then_branch ? "Constant_Then" : "Constant_Else",
"Constant",
then_branch ? "Constant_Then" : "Constant_Else", {}, outputs);
AttributeProto then_constant_attr_proto;
then_constant_attr_proto.set_name("value");
then_constant_attr_proto.set_type(AttributeProto_AttributeType_TENSOR);
auto* then_constant_attr_tensor_proto = then_constant_attr_proto.mutable_t();
then_constant_attr_tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
then_constant_attr_tensor_proto->add_dims(1);
then_constant_attr_tensor_proto->add_float_data(value); // Constant value of 10.f
then_constant_node.AddAttribute("value", then_constant_attr_proto);
auto status_then = graph_then.Resolve();
EXPECT_EQ(status_then, Status::OK());
auto& graphproto_then = graph_then.ToGraphProto();
return graphproto_then;
};
if_node.AddAttribute("then_branch", CreateSubgraphWithConstantNode(true, 10.f, outputs));
if_node.AddAttribute("else_branch", CreateSubgraphWithConstantNode(false, 1000.f, outputs));
}
}
};
// Context: Github issue #3900
TEST(If, ConditionalBranchesOnlyContainConstantNodes_ThenBranchExecution) {
IfOpTesterOnlyConstantNodesInConditionalBranches test;
test.AddInput<bool>("If_input", {1}, {true});
test.AddOutput<float>("If_output", {1}, {10.f});
test.Run();
}
// Context: Github issue #3900
TEST(If, ConditionalBranchesOnlyContainConstantNodes_ElseBranchExecution) {
IfOpTesterOnlyConstantNodesInConditionalBranches test;
test.AddInput<bool>("If_input", {1}, {false});
test.AddOutput<float>("If_output", {1}, {1000.f});
test.Run();
}
} // namespace test
} // namespace onnxruntime