hotfix for skiplayernorm (#3543)

Co-authored-by: Ethan Tao <ettao@microsoft.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
This commit is contained in:
ytaous 2020-04-17 01:22:08 -07:00 committed by GitHub
parent 92269ae409
commit fcb27c4e8b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -271,9 +271,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// remove all the other nodes.
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node);
#ifdef ENABLE_TRAINING
// add two extra output defs, so we have 3 output defs that match what gradient builder expected
layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean"), nullptr));
layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std_var"), nullptr));
#endif
modified = true;
}