mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Further populate Stop Gradient list (#5021)
* Add to Stop Gradient list * Improve Stop gradient
This commit is contained in:
parent
e1ed0fde2b
commit
38453acae3
2 changed files with 28 additions and 2 deletions
|
|
@ -91,7 +91,8 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
|
|||
for (auto edge_it = n->InputEdgesBegin(); edge_it != n->InputEdgesEnd(); ++edge_it) {
|
||||
auto it = STOP_GRADIENT_EDGES.find(n->OpType());
|
||||
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
|
||||
LOGS(logger_, WARNING) << "Skip building gradient for node: " << edge_it->GetNode().Name() ;
|
||||
LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << n->Name();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -159,6 +160,13 @@ Status GradientGraphBuilder::Build() {
|
|||
|
||||
if (reachable_nodes.find(&next_node) == reachable_nodes.end()) continue;
|
||||
|
||||
auto it = STOP_GRADIENT_EDGES.find(next_node.OpType());
|
||||
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
|
||||
LOGS(logger_, WARNING) << "Skip building gradient for input_" << edge_it->GetDstArgIndex()
|
||||
<< " of node: " << next_node.Name();
|
||||
continue;
|
||||
}
|
||||
|
||||
const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()];
|
||||
string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name());
|
||||
|
||||
|
|
|
|||
|
|
@ -27,9 +27,24 @@ typedef std::set<const Node*, NodeCompare> NodeSet;
|
|||
|
||||
static std::unordered_map<std::string, std::unordered_set<size_t>>
|
||||
STOP_GRADIENT_EDGES = {
|
||||
{"Pow", {1}},
|
||||
{"Not", {0}},
|
||||
{"And", {0, 1}},
|
||||
{"Or", {0, 1}},
|
||||
{"Xor", {0, 1}},
|
||||
{"Equal", {0, 1}},
|
||||
{"Less", {0, 1}},
|
||||
{"LessOrEqual", {0, 1}},
|
||||
{"Greater", {0, 1}},
|
||||
{"GreaterOrEqual", {0, 1}},
|
||||
{"IsInf", {0}},
|
||||
{"IsNaN", {0}},
|
||||
{"NonZero", {0}},
|
||||
{"Pow", {1}}, // TODO: Pow's input_1 is differentiable, but gradient not yet implemented
|
||||
{"Gather", {1}},
|
||||
{"GatherElements", {1}},
|
||||
{"GatherND", {1}},
|
||||
{"Shape", {0}},
|
||||
{"Size", {0}},
|
||||
{"Reshape", {1}},
|
||||
{"Expand", {1}},
|
||||
{"TrainableDropout", {1}},
|
||||
|
|
@ -39,10 +54,13 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
|
|||
{"SoftmaxCrossEntropyLoss", {1, 2}},
|
||||
{"ConstantOfShape", {0}},
|
||||
{"Scatter", {1}},
|
||||
{"ScatterElements", {1}},
|
||||
{"ScatterND", {1}},
|
||||
{"OneHot", {0, 1, 2}},
|
||||
{"Where", {0}},
|
||||
{"Range", {0, 1, 2}},
|
||||
{"Tile", {1}},
|
||||
{"NonZero", {0}},
|
||||
{"BroadcastGradientArgs", {0, 1}}};
|
||||
|
||||
class GradientGraphBuilder {
|
||||
|
|
|
|||
Loading…
Reference in a new issue