mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Construct valid graphs for ONNX checker for IR version < 4. (#9665)
* Construct valid graphs for ONNX checker for IR version < 4. Previously the constructed graph was not guaranteed to have its initializers be a subset of its inputs, which is required for IR version < 4. This resulted in spurious failures. Fixes #9663
This commit is contained in:
parent
32c896df6d
commit
93e239747f
3 changed files with 79 additions and 1 deletions
|
|
@ -99,7 +99,7 @@ class NodeArg {
|
|||
|
||||
#endif // !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
/** Gets this NodeArg as a ValueInfoProto. */
|
||||
/** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
|
||||
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
|
||||
|
||||
/** Gets a flag indicating whether this NodeArg exists or not.
|
||||
|
|
|
|||
|
|
@ -1104,6 +1104,12 @@ Graph::Graph(const Model& owning_model,
|
|||
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
|
||||
auto status = utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor);
|
||||
ORT_ENFORCE(status.IsOK(), status.ToString());
|
||||
// Ensure initializers are also graph inputs.
|
||||
if (ir_version_ < 4) {
|
||||
TypeProto t{TypeProtoFromTensorProto(*tensor)};
|
||||
const NodeArg& node_arg = GetOrCreateNodeArg(tensor->name(), &t);
|
||||
*(graph_proto_->add_input()) = node_arg.ToProto();
|
||||
}
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
if (node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) {
|
||||
auto p = sparse_tensor_names_.emplace(tensor->name());
|
||||
|
|
|
|||
|
|
@ -1961,5 +1961,77 @@ TEST_F(GraphTest, DontRemoveUnusedInitializerWithGraphInput) {
|
|||
|
||||
ASSERT_NE(j, inputs_including_initializers.cend()) << "Unused initializer was incorrectly removed.";
|
||||
}
|
||||
|
||||
// The Model class internally:
|
||||
// 1. Converts the ONNX Model to an ORT Model
|
||||
// 2. Converts the ORT Model's Graph back to an ONNX Graph so that it
|
||||
// can run the ONNX Checker on it.
|
||||
// Previously this was buggy for models containing subgraphs with IR Version < 4.
|
||||
// We didn't always ensure that initializers are a subset of inputs, which is
|
||||
// required for IR version < 4.
|
||||
TEST_F(GraphTest, ConstantsBecomeInitializersAndInputs) {
|
||||
ModelProto m;
|
||||
m.set_ir_version(ONNX_NAMESPACE::IR_VERSION_2017_11_3);
|
||||
ImportOpset(m, "", 13);
|
||||
GraphProto* g = m.mutable_graph();
|
||||
g->set_name("test");
|
||||
|
||||
// Construct "output = if x: 2.0 else 1.0"
|
||||
ValueInfoProto* x = g->add_input();
|
||||
x->set_name("x");
|
||||
SetTypeAndShape(x->mutable_type()->mutable_tensor_type(), TensorProto_DataType_BOOL, {1});
|
||||
|
||||
NodeProto* if_node = g->add_node();
|
||||
if_node->set_op_type("If");
|
||||
if_node->set_name("If");
|
||||
if_node->add_input("x");
|
||||
if_node->add_output("output");
|
||||
|
||||
AttributeProto* if_attr = if_node->add_attribute();
|
||||
if_attr->set_name("then_branch");
|
||||
if_attr->set_type(AttributeProto_AttributeType_GRAPH);
|
||||
GraphProto* then_g = if_attr->mutable_g();
|
||||
then_g->set_name("then");
|
||||
ValueInfoProto* then_out = then_g->add_output();
|
||||
then_out->set_name("then_out");
|
||||
SetTypeAndShape(then_out->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1});
|
||||
NodeProto* two_node = then_g->add_node();
|
||||
two_node->set_op_type("Constant");
|
||||
AttributeProto* two_attr = two_node->add_attribute();
|
||||
two_attr->set_name("value");
|
||||
two_attr->set_type(AttributeProto_AttributeType_TENSOR);
|
||||
two_attr->mutable_t()->add_float_data(2.0);
|
||||
two_attr->mutable_t()->set_data_type(TensorProto_DataType_FLOAT);
|
||||
two_attr->mutable_t()->add_dims(1);
|
||||
two_node->set_name("Constant_two");
|
||||
two_node->add_output("then_out");
|
||||
|
||||
AttributeProto* else_attr = if_node->add_attribute();
|
||||
else_attr->set_name("else_branch");
|
||||
else_attr->set_type(AttributeProto_AttributeType_GRAPH);
|
||||
GraphProto* else_g = else_attr->mutable_g();
|
||||
else_g->set_name("else");
|
||||
ValueInfoProto* else_out = else_g->add_output();
|
||||
else_out->set_name("else_out");
|
||||
SetTypeAndShape(else_out->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1});
|
||||
NodeProto* one_node = else_g->add_node();
|
||||
one_node->set_op_type("Constant");
|
||||
AttributeProto* one_attr = one_node->add_attribute();
|
||||
one_attr->set_name("value");
|
||||
one_attr->set_type(AttributeProto_AttributeType_TENSOR);
|
||||
one_attr->mutable_t()->add_float_data(1.0);
|
||||
one_attr->mutable_t()->set_data_type(TensorProto_DataType_FLOAT);
|
||||
one_attr->mutable_t()->add_dims(1);
|
||||
one_node->set_name("Constant_one");
|
||||
one_node->add_output("else_out");
|
||||
|
||||
ValueInfoProto* output = g->add_output();
|
||||
output->set_name("output");
|
||||
SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), TensorProto_DataType_FLOAT, {1});
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
Status st = Model::Load(std::move(m), model, nullptr, *logger_);
|
||||
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue