Fix loss function builder (#3801)

This commit is contained in:
Sherlock 2020-05-04 10:41:15 -07:00 committed by GitHub
parent 785b45124d
commit 2f8a2364c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -52,6 +52,7 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()(
const std::string& prob_name = prediction_name + "_probability";
GraphAugmenter::GraphDefs graph_defs;
graph_defs.AddGraphInputs({label_name});
graph_defs.AddGraphOutputs({loss_name});
std::vector<NodeDef> new_nodes;
@ -76,6 +77,8 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()(
NodeAttributes(),
"SoftmaxCrossEntropy" // name
));
graph_defs.AddGraphInputs({weight_name});
} else {
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy", // Op
{ArgDef(prediction_name),
@ -103,6 +106,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()(
const std::string& prob_name = prediction_name + "_probability";
GraphAugmenter::GraphDefs graph_defs;
graph_defs.AddGraphInputs({label_name});
graph_defs.AddGraphOutputs({loss_name});
std::vector<NodeDef> new_nodes;
@ -129,6 +133,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()(
NodeAttributes(),
"SoftmaxCrossEntropy" // name
));
graph_defs.AddGraphInputs({weight_name});
} else {
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", // Op
{ArgDef(prediction_name),