mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Fix loss function builder (#3801)
This commit is contained in:
parent
785b45124d
commit
2f8a2364c3
1 changed files with 5 additions and 0 deletions
|
|
@ -52,6 +52,7 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()(
|
|||
const std::string& prob_name = prediction_name + "_probability";
|
||||
|
||||
GraphAugmenter::GraphDefs graph_defs;
|
||||
graph_defs.AddGraphInputs({label_name});
|
||||
graph_defs.AddGraphOutputs({loss_name});
|
||||
std::vector<NodeDef> new_nodes;
|
||||
|
||||
|
|
@ -76,6 +77,8 @@ GraphAugmenter::GraphDefs SparseSoftmaxCrossEntropy::operator()(
|
|||
NodeAttributes(),
|
||||
"SoftmaxCrossEntropy" // name
|
||||
));
|
||||
|
||||
graph_defs.AddGraphInputs({weight_name});
|
||||
} else {
|
||||
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy", // Op
|
||||
{ArgDef(prediction_name),
|
||||
|
|
@ -103,6 +106,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()(
|
|||
const std::string& prob_name = prediction_name + "_probability";
|
||||
|
||||
GraphAugmenter::GraphDefs graph_defs;
|
||||
graph_defs.AddGraphInputs({label_name});
|
||||
graph_defs.AddGraphOutputs({loss_name});
|
||||
std::vector<NodeDef> new_nodes;
|
||||
|
||||
|
|
@ -129,6 +133,7 @@ GraphAugmenter::GraphDefs SoftmaxCrossEntropyLoss::operator()(
|
|||
NodeAttributes(),
|
||||
"SoftmaxCrossEntropy" // name
|
||||
));
|
||||
graph_defs.AddGraphInputs({weight_name});
|
||||
} else {
|
||||
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", // Op
|
||||
{ArgDef(prediction_name),
|
||||
|
|
|
|||
Loading…
Reference in a new issue