diff --git a/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc b/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc index b03f095558..a30b3964fa 100644 --- a/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc +++ b/orttraining/orttraining/core/graph/loss_func/softmax_cross_entropy.cc @@ -52,6 +52,7 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()( const std::string& prob_name = prediction_name + "_probability"; GraphAugmenter::GraphDefs graph_defs; + graph_defs.AddGraphInputs({label_name}); graph_defs.AddGraphOutputs({loss_name}); std::vector new_nodes; @@ -76,6 +77,8 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()( NodeAttributes(), "SoftmaxCrossEntropy" // name )); + + graph_defs.AddGraphInputs({weight_name}); } else { new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy", // Op {ArgDef(prediction_name), @@ -103,6 +106,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()( const std::string& prob_name = prediction_name + "_probability"; GraphAugmenter::GraphDefs graph_defs; + graph_defs.AddGraphInputs({label_name}); graph_defs.AddGraphOutputs({loss_name}); std::vector new_nodes; @@ -129,6 +133,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()( NodeAttributes(), "SoftmaxCrossEntropy" // name )); + graph_defs.AddGraphInputs({weight_name}); } else { new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", // Op {ArgDef(prediction_name),