Further populate Stop Gradient list (#5021)

* Add to Stop Gradient list

* Improve Stop gradient
This commit is contained in:
Sherlock 2020-09-08 12:49:09 -07:00 committed by GitHub
parent e1ed0fde2b
commit 38453acae3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 2 deletions

View file

@ -91,7 +91,8 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) {
auto it = STOP_GRADIENT_EDGES.find(n->OpType());
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
LOGS(logger_, WARNING) << "Skip building gradient for node: " << edge_it->GetNode().Name() ;
LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << n->Name();
continue;
}
@ -159,6 +160,13 @@ Status GradientGraphBuilder::Build() {
if (reachable_nodes.find(&next_node) == reachable_nodes.end()) continue;
auto it = STOP_GRADIENT_EDGES.find(next_node.OpType());
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
<< " of node: " << next_node.Name();
continue;
}
const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()];
string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name());

View file

@ -27,9 +27,24 @@ typedef std::set<const Node*, NodeCompare> NodeSet;
static std::unordered_map<std::string, std::unordered_set<size_t>>
STOP_GRADIENT_EDGES = {
{"Pow", {1}},
{"Not", {0}},
{"And", {0, 1}},
{"Or", {0, 1}},
{"Xor", {0, 1}},
{"Equal", {0, 1}},
{"Less", {0, 1}},
{"LessOrEqual", {0, 1}},
{"Greater", {0, 1}},
{"GreaterOrEqual", {0, 1}},
{"IsInf", {0}},
{"IsNaN", {0}},
{"NonZero", {0}},
{"Pow", {1}}, // TODO: Pow's input_1 is differentiable, but gradient not yet implemented
{"Gather", {1}},
{"GatherElements", {1}},
{"GatherND", {1}},
{"Shape", {0}},
{"Size", {0}},
{"Reshape", {1}},
{"Expand", {1}},
{"TrainableDropout", {1}},
@ -39,10 +54,13 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"SoftmaxCrossEntropyLoss", {1, 2}},
{"ConstantOfShape", {0}},
{"Scatter", {1}},
{"ScatterElements", {1}},
{"ScatterND", {1}},
{"OneHot", {0, 1, 2}},
{"Where", {0}},
{"Range", {0, 1, 2}},
{"Tile", {1}},
{"NonZero", {0}},
{"BroadcastGradientArgs", {0, 1}}};
class GradientGraphBuilder {