From dd73e8c016a7b8037818d8f54aafb2df2ffb13b5 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Mon, 6 Jul 2020 15:17:27 -0700 Subject: [PATCH] add function initialization back to graph resolve (#4434) --- onnxruntime/core/graph/graph.cc | 2 ++ .../orttraining/core/graph/mixed_precision_transformer.cc | 7 ------- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4993bf7f36..e8350b5a6c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2049,6 +2049,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { node.op_ = nullptr; } + InitFunctionBodyForNode(node); + if (!node.op_) { return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op"); } diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index 09cc5a9fca..52ce76d8c8 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -602,13 +602,6 @@ Status TransformGraphForMixedPrecision(Graph& graph, LossSubgraph loss_subgraph(graph); // Stage 1: Convert whole graph including forward and backward to FP16 - // Initialize function body for all function nodes - // This is required to make sure after converting inputs\weights to FP16 - // the new NodeArg updates are correctly propagated to the function body nodes as well. - for (auto& node : graph.Nodes()) { - graph.InitFunctionBodyForNode(node); - } - // Insert Cast node to convert inputs from FP32 to FP16 // If all consumers are from loss graph, don't convert it, and remove it from To-32 loss graph inputs. for (const NodeArg* input : graph.GetInputs()) {