From 2f8a2364c351d7a75dce01617d016095bc0a1b71 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Mon, 4 May 2020 10:41:15 -0700 Subject: [PATCH] Fix loss function builder (#3801) --- .../core/graph/loss_func/softmax_cross_entropy.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc b/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc index b03f095558..a30b3964fa 100644 --- a/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc +++ b/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc @@ -52,6 +52,7 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()( const std::string& prob_name = prediction_name + "_probability"; GraphAugmenter::GraphDefs graph_defs; + graph_defs.AddGraphInputs({label_name}); graph_defs.AddGraphOutputs({loss_name}); std::vector new_nodes; @@ -76,6 +77,8 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()( NodeAttributes(), "SoftmaxCrossEntropy" // name )); + + graph_defs.AddGraphInputs({weight_name}); } else { new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy", // Op {ArgDef(prediction_name), @@ -103,6 +106,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()( const std::string& prob_name = prediction_name + "_probability"; GraphAugmenter::GraphDefs graph_defs; + graph_defs.AddGraphInputs({label_name}); graph_defs.AddGraphOutputs({loss_name}); std::vector new_nodes; @@ -129,6 +133,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()( NodeAttributes(), "SoftmaxCrossEntropy" // name )); + graph_defs.AddGraphInputs({weight_name}); } else { new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", // Op {ArgDef(prediction_name),