mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Reshape inputs for SoftmaxCrossEntropyLoss instead of transposing them. (#4551)
This commit is contained in:
parent
bc1d197ddf
commit
9d80235607
1 changed files with 38 additions and 84 deletions
|
|
@ -47,49 +47,6 @@ TypeProto* BertLoss::GetGatheredPredictionTypeProto(const NodeArg* prediction_ar
|
|||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLoss::GetTransposedTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs) {
|
||||
ORT_ENFORCE(prediction_arg != nullptr, "GetTransposedTypeProto's prediction_arg is nullptr");
|
||||
const auto* logits_type_proto = prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
|
||||
TypeProto* type_proto = graph_defs.CreateTypeProto();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
|
||||
// Batch size.
|
||||
target_shape->add_dim()->CopyFrom(dims[0]);
|
||||
|
||||
// Class.
|
||||
target_shape->add_dim()->CopyFrom(dims[dims.size() - 1]);
|
||||
|
||||
for (int i = 1; i < dims.size() - 1; i++) {
|
||||
target_shape->add_dim()->CopyFrom(dims[i]);
|
||||
}
|
||||
|
||||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLoss::GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs) {
|
||||
ORT_ENFORCE(prediction_arg != nullptr, "GetGatheredPredictionTransposedTypeProto's prediction_arg is nullptr");
|
||||
const auto* logits_type_proto = prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
|
||||
TypeProto* type_proto = graph_defs.CreateTypeProto();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
|
||||
// Batch size.
|
||||
target_shape->add_dim()->CopyFrom(dims[0]);
|
||||
// Vocab size.
|
||||
target_shape->add_dim()->CopyFrom(dims[2]);
|
||||
// Prediction count.
|
||||
target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
|
||||
|
||||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
|
||||
return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
}
|
||||
|
|
@ -108,7 +65,6 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
|
|||
|
||||
std::vector<NodeDef> new_nodes;
|
||||
GraphAugmenter::GraphDefs graph_defs;
|
||||
|
||||
// LabelSoftmaxCrossEntropy for masked_lm
|
||||
{
|
||||
const NodeArg* prediction_arg = graph.GetNodeArg(prediction_masked_lm);
|
||||
|
|
@ -130,30 +86,50 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
|
|||
{ArgDef("gathered_prediction", gathered_prediction_type_proto)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast<int64_t>(1))},
|
||||
"GATHERED_LM"));
|
||||
|
||||
ONNX_NAMESPACE::TensorProto t_proto;
|
||||
t_proto.add_dims(2);
|
||||
t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
t_proto.add_int64_data(static_cast<int64_t>(-1));
|
||||
t_proto.add_int64_data(prediction_arg->TypeAsProto()->tensor_type().shape().dim()[2].dim_value());
|
||||
new_nodes.emplace_back(NodeDef("Constant",
|
||||
{},
|
||||
{ArgDef("logit_reshape", nullptr)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)}));
|
||||
|
||||
// Transpose gathered_predictions with the following permutation: {0, 2, 1} because SoftmaxCrossEntropyLoss accepts
|
||||
// scores of the shape {N, C, D1, D2...Dk}.
|
||||
TypeProto* gathered_prediction_transposed_type_proto = GetGatheredPredictionTransposedTypeProto(prediction_arg,
|
||||
graph_defs);
|
||||
new_nodes.emplace_back(NodeDef("Reshape",
|
||||
{ArgDef("gathered_prediction", gathered_prediction_type_proto),
|
||||
ArgDef("logit_reshape")}, // Inputs
|
||||
{ArgDef("gathered_prediction_reshaped")}, // Outputs
|
||||
NodeAttributes(),
|
||||
"Reshape_gathered_prediction"));
|
||||
|
||||
new_nodes.emplace_back(NodeDef("Transpose",
|
||||
{ArgDef("gathered_prediction", gathered_prediction_type_proto)}, // Inputs
|
||||
{ArgDef("gathered_prediction_transposed",
|
||||
gathered_prediction_transposed_type_proto)}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("perm", std::vector<int64_t>{static_cast<int64_t>(0),
|
||||
static_cast<int64_t>(2), static_cast<int64_t>(1)})},
|
||||
ONNX_NAMESPACE::TensorProto t_proto_label;
|
||||
t_proto_label.add_dims(1);
|
||||
t_proto_label.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
t_proto_label.add_int64_data(static_cast<int64_t>(-1));
|
||||
|
||||
"Transpose_gathered_prediction"));
|
||||
new_nodes.emplace_back(NodeDef("Constant",
|
||||
{},
|
||||
{ArgDef("label_reshape", nullptr)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("value", t_proto_label)}));
|
||||
|
||||
new_nodes.emplace_back(NodeDef("Reshape",
|
||||
{ArgDef("masked_lm_ids", masked_lm_int64_type_proto),
|
||||
ArgDef("label_reshape")}, // Inputs
|
||||
{ArgDef("masked_lm_ids_reshaped")}, // Outputs
|
||||
NodeAttributes(),
|
||||
"Reshape_label"));
|
||||
|
||||
std::vector<AttributeProto> attrs;
|
||||
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("ignore_index", static_cast<int64_t>(-1)));
|
||||
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("reduction", "mean"));
|
||||
|
||||
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
|
||||
{ArgDef("gathered_prediction_transposed", gathered_prediction_transposed_type_proto),
|
||||
ArgDef(masked_lm_ids, masked_lm_int64_type_proto)}, // Inputs
|
||||
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
|
||||
ArgDef("probability_lm", gathered_prediction_transposed_type_proto)},
|
||||
{ArgDef("gathered_prediction_reshaped"),
|
||||
ArgDef("masked_lm_ids_reshaped")}, // Inputs
|
||||
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
|
||||
ArgDef("probability_lm")},
|
||||
attrs,
|
||||
"Masked_LM_Loss"));
|
||||
}
|
||||
|
|
@ -167,33 +143,11 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
|
|||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
graph_defs);
|
||||
|
||||
// Transpose prediction_next_sentence with the following permutation: {0, n-1, 1, 2....n-2} because
|
||||
// SoftmaxCrossEntropyLoss accepts scores of the shape {N, C, D1, D2...Dk}.
|
||||
|
||||
TypeProto* prediction_next_sentence_transposed_type_proto = GetTransposedTypeProto(ns_prediction_arg, graph_defs);
|
||||
const auto* logits_type_proto = ns_prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
std::vector<int64_t> prediction_next_sentence_transposed_perm;
|
||||
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(0));
|
||||
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(dims.size() - 1));
|
||||
|
||||
for (int i = 1; i < dims.size() - 1; i++) {
|
||||
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(i));
|
||||
}
|
||||
|
||||
new_nodes.emplace_back(NodeDef("Transpose",
|
||||
{ArgDef(prediction_next_sentence)}, // Inputs
|
||||
{ArgDef("prediction_next_sentence_transposed",
|
||||
prediction_next_sentence_transposed_type_proto)}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("perm", prediction_next_sentence_transposed_perm)},
|
||||
|
||||
"Transpose_prediction_next_sentence"));
|
||||
|
||||
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
|
||||
{ArgDef("prediction_next_sentence_transposed", prediction_next_sentence_transposed_type_proto),
|
||||
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
|
||||
{ArgDef(prediction_next_sentence),
|
||||
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
|
||||
{ArgDef(nsp_loss, GetLossTypeProto(graph_defs)),
|
||||
ArgDef("probability_ns", prediction_next_sentence_transposed_type_proto)}, // Outputs
|
||||
ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
|
||||
"Next_Sentence_Loss"));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue