diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index fdcab6c247..327639f425 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -271,9 +271,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // remove all the other nodes. graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node); +#ifdef ENABLE_TRAINING // add two extra output defs, so we have 3 output defs that match what gradient builder expected layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean"), nullptr)); layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std_var"), nullptr)); +#endif modified = true; }