mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
hotfix for skiplayernorm (#3543)
Co-authored-by: Ethan Tao <ettao@microsoft.com> Co-authored-by: Changming Sun <chasun@microsoft.com>
This commit is contained in:
parent
92269ae409
commit
fcb27c4e8b
1 changed files with 2 additions and 0 deletions
|
|
@ -271,9 +271,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// remove all the other nodes.
|
||||
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// add two extra output defs, so we have 3 output defs that match what gradient builder expected
|
||||
layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean"), nullptr));
|
||||
layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std_var"), nullptr));
|
||||
#endif
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue