Construct valid graphs for ONNX checker for IR version < 4. (#9665)

* Construct valid graphs for ONNX checker for IR version < 4.

Previously the constructed graph was not guaranteed to have its
initializers be a subset of its inputs, which is required for IR
version < 4. This resulted in spurious failures.

Fixes #9663
This commit is contained in:
Gary Miguel 2021-11-11 15:13:28 -08:00 committed by GitHub
parent 32c896df6d
commit 93e239747f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 1 deletions

View file

@ -99,7 +99,7 @@ class NodeArg {
#endif // !defined(ORT_MINIMAL_BUILD)
/** Gets this NodeArg as a ValueInfoProto. */
/** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
/** Gets a flag indicating whether this NodeArg exists or not.

View file

@ -1104,6 +1104,12 @@ Graph::Graph(const Model& owning_model,
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
auto status = utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor);
ORT_ENFORCE(status.IsOK(), status.ToString());
// Ensure initializers are also graph inputs.
if (ir_version_ < 4) {
TypeProto t{TypeProtoFromTensorProto(*tensor)};
const NodeArg& node_arg = GetOrCreateNodeArg(tensor->name(), &t);
*(graph_proto_->add_input()) = node_arg.ToProto();
}
#if !defined(DISABLE_SPARSE_TENSORS)
if (node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) {
auto p = sparse_tensor_names_.emplace(tensor->name());

View file

@ -1961,5 +1961,77 @@ TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) {
ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed.";
}
// The Model class internally:
// 1. Converts the ONNX Model to an ORT Model
// 2. Converts the ORT Model's Graph back to an ONNX Graph so that it
// can run the ONNX Checker on it.
// Previously this was buggy for models containing subgraphs with IR Version < 4.
// We didn't always ensure that initializers are a subset of inputs, which is
// required for IR version < 4.
TEST_F(GraphTest, ConstantsBecomeInitializersAndInputs) {
ModelProto m;
m.set_ir_version(ONNX_NAMESPACE::IR_VERSION_2017_11_3);
ImportOpset(m, "", 13);
GraphProto* g = m.mutable_graph();
g->set_name("test");
// Construct "output = if x: 2.0 else 1.0"
ValueInfoProto* x = g->add_input();
x->set_name("x");
SetTypeAndShape(x->mutable_type()->mutable_tensor_type(), TensorProto_DataType_BOOL, {1});
NodeProto* if_node = g->add_node();
if_node->set_op_type("If");
if_node->set_name("If");
if_node->add_input("x");
if_node->add_output("output");
AttributeProto* if_attr = if_node->add_attribute();
if_attr->set_name("then_branch");
if_attr->set_type(AttributeProto_AttributeType_GRAPH);
GraphProto* then_g = if_attr->mutable_g();
then_g->set_name("then");
ValueInfoProto* then_out = then_g->add_output();
then_out->set_name("then_out");
SetTypeAndShape(then_out->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1});
NodeProto* two_node = then_g->add_node();
two_node->set_op_type("Constant");
AttributeProto* two_attr = two_node->add_attribute();
two_attr->set_name("value");
two_attr->set_type(AttributeProto_AttributeType_TENSOR);
two_attr->mutable_t()->add_float_data(2.0);
two_attr->mutable_t()->set_data_type(TensorProto_DataType_FLOAT);
two_attr->mutable_t()->add_dims(1);
two_node->set_name("Constant_two");
two_node->add_output("then_out");
AttributeProto* else_attr = if_node->add_attribute();
else_attr->set_name("else_branch");
else_attr->set_type(AttributeProto_AttributeType_GRAPH);
GraphProto* else_g = else_attr->mutable_g();
else_g->set_name("else");
ValueInfoProto* else_out = else_g->add_output();
else_out->set_name("else_out");
SetTypeAndShape(else_out->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1});
NodeProto* one_node = else_g->add_node();
one_node->set_op_type("Constant");
AttributeProto* one_attr = one_node->add_attribute();
one_attr->set_name("value");
one_attr->set_type(AttributeProto_AttributeType_TENSOR);
one_attr->mutable_t()->add_float_data(1.0);
one_attr->mutable_t()->set_data_type(TensorProto_DataType_FLOAT);
one_attr->mutable_t()->add_dims(1);
one_node->set_name("Constant_one");
one_node->add_output("else_out");
ValueInfoProto* output = g->add_output();
output->set_name("output");
SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1});
std::shared_ptr<Model> model;
Status st = Model::Load(std::move(m), model, nullptr, *logger_);
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
}
} // namespace test
} // namespace onnxruntime