diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 4467b4027f..d54f145d40 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1522,6 +1522,28 @@ class Graph { }; #if !defined(ORT_MINIMAL_BUILD) +// Print NodeArg as +// name : type +// For example, +// "110": tensor(float) +std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg); +// Print Node as, +// (operator's name, operator's type, domain, version) : (input0, input1, ...) -> (output0, output1, ...) +// For example, +// ("Add_14", Add, "", 7) : ("110": tensor(float),"109": tensor(float),) -> ("111": tensor(float),) +std::ostream& operator<<(std::ostream& out, const Node& node); +// Print Graph as, for example, +// Inputs: +// "Input": tensor(float) +// Nodes: +// ("add0", Add, "", 7) : ("Input": tensor(float),"Bias": tensor(float),) -> ("add0_out": tensor(float),) +// ("matmul", MatMul, "", 9) : ("add0_out": tensor(float),"matmul_weight": tensor(float),) -> ("matmul_out": tensor(float),) +// ("add1", Add, "", 7) : ("matmul_out": tensor(float),"add_weight": tensor(float),) -> ("add1_out": tensor(float),) +// ("reshape", Reshape, "", 5) : ("add1_out": tensor(float),"concat_out": tensor(int64),) -> ("Result": tensor(float),) +// Outputs: +// "Result": tensor(float) +// Inputs' and outputs' format is described in document of NodeArg's operator<< above. +// Node format is described in Node's operator<< above. std::ostream& operator<<(std::ostream& out, const Graph& graph); #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 011ec52e7e..03acfec47f 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2517,7 +2517,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "This is an invalid model. Error in Node:", node_name, " : ", ex.what()); + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "This is an invalid model. In Node, ", node, ", Error ", ex.what()); }); } ORT_RETURN_IF_ERROR(status); @@ -4128,32 +4129,72 @@ Graph::~Graph() { } #if !defined(ORT_MINIMAL_BUILD) +std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg) { + out << "\"" << node_arg.Name() << "\""; + if (node_arg.Type()) { + out << ": " << *node_arg.Type(); + } + return out; +} + +std::ostream& operator<<(std::ostream& out, const Node& node) { + out << "(\"" << node.Name() << "\"" + << ", " + << node.OpType() + << ", " + // Use quote so default ONNX domain is shown as "" + // rather than misleading empty string. + << "\"" << node.Domain() << "\"" + << ", " + << node.SinceVersion() + << ") : ("; + for (const auto* x : node.InputDefs()) { + if (x->Exists()) { + out << *x << ","; + } else { + // Print missing (or optional) inputs + // because operator schema uses positional + // arguments in ONNX. + out << "\"\"" + << ","; + } + } + out << ") -> ("; + for (const auto* x : node.OutputDefs()) { + if (x->Exists()) { + out << *x << ","; + } else { + // Print missing (or optional) outputs + // because operator schema uses positional + // arguments in ONNX. + out << "\"\"" + << ","; + } + } + out << ") "; + return out; +} + std::ostream& operator<<(std::ostream& out, const Graph& graph) { out << "Inputs:\n"; - for (auto* x : graph.GetInputs()) { - out << " " << x->Name() << " : " << *x->Type() << "\n"; + for (const auto* x : graph.GetInputs()) { + // Unlike we print missing input and output for operator, we don't + // print missing input for graph because they are not helpful (we + // don't have a fixed schema for graph to match arguments). + if (x) { + out << " " << *x << "\n"; + } } out << "Nodes:\n"; - for (auto& node : graph.Nodes()) { - out << " " << node.Name() << ": " << node.OpType() << " ("; - for (auto* x : node.InputDefs()) { - if (x->Exists()) { - out << x->Name() << ": " << *x->Type(); - } - out << ", "; - } - out << ") -> "; - for (auto* x : node.OutputDefs()) { - if (x->Exists()) { - out << x->Name() << ": " << *x->Type(); - } - out << ", "; - } - out << "\n"; + for (const auto& node : graph.Nodes()) { + out << " " << node << "\n"; } out << "Outputs:\n"; - for (auto* x : graph.GetOutputs()) { - out << " " << x->Name() << " : " << *x->Type() << "\n"; + for (const auto* x : graph.GetOutputs()) { + // Similar to graph input, missing graph output is not printed. + if (x) { + out << " " << *x << "\n"; + } } return out; }