Reshape inputs for SoftmaxCrossEntropyLoss instead of transposing them. (#4551)

This commit is contained in:
M. Zeeshan Siddiqui 2020-07-20 06:33:40 -07:00 committed by GitHub
parent bc1d197ddf
commit 9d80235607
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -47,49 +47,6 @@ TypeProto* BertLoss::GetGatheredPredictionTypeProto(const NodeArg* prediction_ar
return type_proto;
}
TypeProto* BertLoss::GetTransposedTypeProto(const NodeArg* prediction_arg,
GraphAugmenter::GraphDefs& graph_defs) {
ORT_ENFORCE(prediction_arg != nullptr, "GetTransposedTypeProto's prediction_arg is nullptr");
const auto* logits_type_proto = prediction_arg->TypeAsProto();
const auto& dims = logits_type_proto->tensor_type().shape().dim();
TypeProto* type_proto = graph_defs.CreateTypeProto();
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
// Batch size.
target_shape->add_dim()->CopyFrom(dims[0]);
// Class.
target_shape->add_dim()->CopyFrom(dims[dims.size() - 1]);
for (int i = 1; i < dims.size() - 1; i++) {
target_shape->add_dim()->CopyFrom(dims[i]);
}
return type_proto;
}
TypeProto* BertLoss::GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg,
GraphAugmenter::GraphDefs& graph_defs) {
ORT_ENFORCE(prediction_arg != nullptr, "GetGatheredPredictionTransposedTypeProto's prediction_arg is nullptr");
const auto* logits_type_proto = prediction_arg->TypeAsProto();
const auto& dims = logits_type_proto->tensor_type().shape().dim();
TypeProto* type_proto = graph_defs.CreateTypeProto();
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
// Batch size.
target_shape->add_dim()->CopyFrom(dims[0]);
// Vocab size.
target_shape->add_dim()->CopyFrom(dims[2]);
// Prediction count.
target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
return type_proto;
}
TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
}
@ -108,7 +65,6 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
std::vector<NodeDef> new_nodes;
GraphAugmenter::GraphDefs graph_defs;
// LabelSoftmaxCrossEntropy for masked_lm
{
const NodeArg* prediction_arg = graph.GetNodeArg(prediction_masked_lm);
@ -130,30 +86,50 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
{ArgDef("gathered_prediction", gathered_prediction_type_proto)},
{ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast<int64_t>(1))},
"GATHERED_LM"));
ONNX_NAMESPACE::TensorProto t_proto;
t_proto.add_dims(2);
t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
t_proto.add_int64_data(static_cast<int64_t>(-1));
t_proto.add_int64_data(prediction_arg->TypeAsProto()->tensor_type().shape().dim()[2].dim_value());
new_nodes.emplace_back(NodeDef("Constant",
{},
{ArgDef("logit_reshape", nullptr)},
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)}));
// Transpose gathered_predictions with the following permutation: {0, 2, 1} because SoftmaxCrossEntropyLoss accepts
// scores of the shape {N, C, D1, D2...Dk}.
TypeProto* gathered_prediction_transposed_type_proto = GetGatheredPredictionTransposedTypeProto(prediction_arg,
graph_defs);
new_nodes.emplace_back(NodeDef("Reshape",
{ArgDef("gathered_prediction", gathered_prediction_type_proto),
ArgDef("logit_reshape")}, // Inputs
{ArgDef("gathered_prediction_reshaped")}, // Outputs
NodeAttributes(),
"Reshape_gathered_prediction"));
new_nodes.emplace_back(NodeDef("Transpose",
{ArgDef("gathered_prediction", gathered_prediction_type_proto)}, // Inputs
{ArgDef("gathered_prediction_transposed",
gathered_prediction_transposed_type_proto)}, // Outputs
{ONNX_NAMESPACE::MakeAttribute("perm", std::vector<int64_t>{static_cast<int64_t>(0),
static_cast<int64_t>(2), static_cast<int64_t>(1)})},
ONNX_NAMESPACE::TensorProto t_proto_label;
t_proto_label.add_dims(1);
t_proto_label.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
t_proto_label.add_int64_data(static_cast<int64_t>(-1));
"Transpose_gathered_prediction"));
new_nodes.emplace_back(NodeDef("Constant",
{},
{ArgDef("label_reshape", nullptr)},
{ONNX_NAMESPACE::MakeAttribute("value", t_proto_label)}));
new_nodes.emplace_back(NodeDef("Reshape",
{ArgDef("masked_lm_ids", masked_lm_int64_type_proto),
ArgDef("label_reshape")}, // Inputs
{ArgDef("masked_lm_ids_reshaped")}, // Outputs
NodeAttributes(),
"Reshape_label"));
std::vector<AttributeProto> attrs;
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("ignore_index", static_cast<int64_t>(-1)));
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("reduction", "mean"));
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
{ArgDef("gathered_prediction_transposed", gathered_prediction_transposed_type_proto),
ArgDef(masked_lm_ids, masked_lm_int64_type_proto)}, // Inputs
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
ArgDef("probability_lm", gathered_prediction_transposed_type_proto)},
{ArgDef("gathered_prediction_reshaped"),
ArgDef("masked_lm_ids_reshaped")}, // Inputs
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
ArgDef("probability_lm")},
attrs,
"Masked_LM_Loss"));
}
@ -167,33 +143,11 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
ONNX_NAMESPACE::TensorProto_DataType_INT64,
graph_defs);
// Transpose prediction_next_sentence with the following permutation: {0, n-1, 1, 2....n-2} because
// SoftmaxCrossEntropyLoss accepts scores of the shape {N, C, D1, D2...Dk}.
TypeProto* prediction_next_sentence_transposed_type_proto = GetTransposedTypeProto(ns_prediction_arg, graph_defs);
const auto* logits_type_proto = ns_prediction_arg->TypeAsProto();
const auto& dims = logits_type_proto->tensor_type().shape().dim();
std::vector<int64_t> prediction_next_sentence_transposed_perm;
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(0));
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(dims.size() - 1));
for (int i = 1; i < dims.size() - 1; i++) {
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(i));
}
new_nodes.emplace_back(NodeDef("Transpose",
{ArgDef(prediction_next_sentence)}, // Inputs
{ArgDef("prediction_next_sentence_transposed",
prediction_next_sentence_transposed_type_proto)}, // Outputs
{ONNX_NAMESPACE::MakeAttribute("perm", prediction_next_sentence_transposed_perm)},
"Transpose_prediction_next_sentence"));
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
{ArgDef("prediction_next_sentence_transposed", prediction_next_sentence_transposed_type_proto),
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
{ArgDef(prediction_next_sentence),
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
{ArgDef(nsp_loss, GetLossTypeProto(graph_defs)),
ArgDef("probability_ns", prediction_next_sentence_transposed_type_proto)}, // Outputs
ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
"Next_Sentence_Loss"));
}