diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4993bf7f36..e8350b5a6c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2049,6 +2049,8 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { node.op_ = nullptr; } + InitFunctionBodyForNode(node); + if (!node.op_) { return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op"); } diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index 09cc5a9fca..52ce76d8c8 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -602,13 +602,6 @@ Status TransformGraphForMixedPrecision(Graph& graph, LossSubgraph loss_subgraph(graph); // Stage 1: Convert whole graph including forward and backward to FP16 - // Initialize function body for all function nodes - // This is required to make sure after converting inputs\weights to FP16 - // the new NodeArg updates are correctly propagated to the function body nodes as well. - for (auto& node : graph.Nodes()) { - graph.InitFunctionBodyForNode(node); - } - // Insert Cast node to convert inputs from FP32 to FP16 // If all consumers are from loss graph, don't convert it, and remove it from To-32 loss graph inputs. for (const NodeArg* input : graph.GetInputs()) {