Improve print functions for NodeArg, Node, and Graph (#9801)

This commit is contained in:
Wei-Sheng Chin 2021-11-19 09:48:27 -08:00 committed by GitHub
parent 9d3c63263b
commit e520bb5145
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 21 deletions

View file

@ -1522,6 +1522,28 @@ class Graph {
};
#if !defined(ORT_MINIMAL_BUILD)
// Print NodeArg as
// name : type
// For example,
// "110": tensor(float)
std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg);
// Print Node as,
// (operator's name, operator's type, domain, version) : (input0, input1, ...) -> (output0, output1, ...)
// For example,
// ("Add_14", Add, "", 7) : ("110": tensor(float),"109": tensor(float),) -> ("111": tensor(float),)
std::ostream& operator<<(std::ostream& out, const Node& node);
// Print Graph as, for example,
// Inputs:
// "Input": tensor(float)
// Nodes:
// ("add0", Add, "", 7) : ("Input": tensor(float),"Bias": tensor(float),) -> ("add0_out": tensor(float),)
// ("matmul", MatMul, "", 9) : ("add0_out": tensor(float),"matmul_weight": tensor(float),) -> ("matmul_out": tensor(float),)
// ("add1", Add, "", 7) : ("matmul_out": tensor(float),"add_weight": tensor(float),) -> ("add1_out": tensor(float),)
// ("reshape", Reshape, "", 5) : ("add1_out": tensor(float),"concat_out": tensor(int64),) -> ("Result": tensor(float),)
// Outputs:
// "Result": tensor(float)
// Inputs' and outputs' format is described in document of NodeArg's operator<< above.
// Node format is described in Node's operator<< above.
std::ostream& operator<<(std::ostream& out, const Graph& graph);
#endif

View file

@ -2517,7 +2517,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "This is an invalid model. Error in Node:", node_name, " : ", ex.what());
status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
"This is an invalid model. In Node, ", node, ", Error ", ex.what());
});
}
ORT_RETURN_IF_ERROR(status);
@ -4128,32 +4129,72 @@ Graph::~Graph() {
}
#if !defined(ORT_MINIMAL_BUILD)
std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg) {
out << "\"" << node_arg.Name() << "\"";
if (node_arg.Type()) {
out << ": " << *node_arg.Type();
}
return out;
}
std::ostream& operator<<(std::ostream& out, const Node& node) {
out << "(\"" << node.Name() << "\""
<< ", "
<< node.OpType()
<< ", "
// Use quote so default ONNX domain is shown as ""
// rather than misleading empty string.
<< "\"" << node.Domain() << "\""
<< ", "
<< node.SinceVersion()
<< ") : (";
for (const auto* x : node.InputDefs()) {
if (x->Exists()) {
out << *x << ",";
} else {
// Print missing (or optional) inputs
// because operator schema uses positional
// arguments in ONNX.
out << "\"\""
<< ",";
}
}
out << ") -> (";
for (const auto* x : node.OutputDefs()) {
if (x->Exists()) {
out << *x << ",";
} else {
// Print missing (or optional) outputs
// because operator schema uses positional
// arguments in ONNX.
out << "\"\""
<< ",";
}
}
out << ") ";
return out;
}
std::ostream& operator<<(std::ostream& out, const Graph& graph) {
out << "Inputs:\n";
for (auto* x : graph.GetInputs()) {
out << " " << x->Name() << " : " << *x->Type() << "\n";
for (const auto* x : graph.GetInputs()) {
// Unlike we print missing input and output for operator, we don't
// print missing input for graph because they are not helpful (we
// don't have a fixed schema for graph to match arguments).
if (x) {
out << " " << *x << "\n";
}
}
out << "Nodes:\n";
for (auto& node : graph.Nodes()) {
out << " " << node.Name() << ": " << node.OpType() << " (";
for (auto* x : node.InputDefs()) {
if (x->Exists()) {
out << x->Name() << ": " << *x->Type();
}
out << ", ";
}
out << ") -> ";
for (auto* x : node.OutputDefs()) {
if (x->Exists()) {
out << x->Name() << ": " << *x->Type();
}
out << ", ";
}
out << "\n";
for (const auto& node : graph.Nodes()) {
out << " " << node << "\n";
}
out << "Outputs:\n";
for (auto* x : graph.GetOutputs()) {
out << " " << x->Name() << " : " << *x->Type() << "\n";
for (const auto* x : graph.GetOutputs()) {
// Similar to graph input, missing graph output is not printed.
if (x) {
out << " " << *x << "\n";
}
}
return out;
}