From 38453acae3d90284e04813b8efc22f9cdada0fa8 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Tue, 8 Sep 2020 12:49:09 -0700 Subject: [PATCH] Further populate Stop Gradient list (#5021) * Add to Stop Gradient list * Improve Stop gradient --- .../core/framework/gradient_graph_builder.cc | 10 +++++++++- .../core/framework/gradient_graph_builder.h | 20 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index 8f0e01a6bb..6681e92daf 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -91,7 +91,8 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) { for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) { auto it = STOP_GRADIENT_EDGES.find(n->OpType()); if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) { - LOGS(logger_, WARNING) << "Skip building gradient for node: " << edge_it->GetNode().Name() ; + LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << n->Name(); continue; } @@ -159,6 +160,13 @@ Status GradientGraphBuilder::Build() { if (reachable_nodes.find(&next_node) == reachable_nodes.end()) continue; + auto it = STOP_GRADIENT_EDGES.find(next_node.OpType()); + if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) { + LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << next_node.Name(); + continue; + } + const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()]; string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name()); diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 9bfbfaf28e..ae1f203c4e 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -27,9 +27,24 @@ typedef std::set NodeSet; static std::unordered_map> STOP_GRADIENT_EDGES = { - {"Pow", {1}}, + {"Not", {0}}, + {"And", {0, 1}}, + {"Or", {0, 1}}, + {"Xor", {0, 1}}, + {"Equal", {0, 1}}, + {"Less", {0, 1}}, + {"LessOrEqual", {0, 1}}, + {"Greater", {0, 1}}, + {"GreaterOrEqual", {0, 1}}, + {"IsInf", {0}}, + {"IsNaN", {0}}, + {"NonZero", {0}}, + {"Pow", {1}}, // TODO: Pow's input_1 is differentiable, but gradient not yet implemented {"Gather", {1}}, + {"GatherElements", {1}}, + {"GatherND", {1}}, {"Shape", {0}}, + {"Size", {0}}, {"Reshape", {1}}, {"Expand", {1}}, {"TrainableDropout", {1}}, @@ -39,10 +54,13 @@ static std::unordered_map> {"SoftmaxCrossEntropyLoss", {1, 2}}, {"ConstantOfShape", {0}}, {"Scatter", {1}}, + {"ScatterElements", {1}}, + {"ScatterND", {1}}, {"OneHot", {0, 1, 2}}, {"Where", {0}}, {"Range", {0, 1, 2}}, {"Tile", {1}}, + {"NonZero", {0}}, {"BroadcastGradientArgs", {0, 1}}}; class GradientGraphBuilder {