Replace loss function in BERT_LOSS with SoftmaxCrossEntropyLoss. (#4509)

* Replace loss function in BERT_LOSS with SoftmaxCrossEntropyLoss.

* Update BERT loss function with correct logit shapes for softmax cross entropy loss.

* Fix test and address PR comments.
This commit is contained in:
M. Zeeshan Siddiqui 2020-07-16 15:28:24 -07:00 committed by GitHub
parent 76b31d6ce2
commit b43ce2d7ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 105 additions and 27 deletions

Binary file not shown.

View file

@ -47,6 +47,49 @@ TypeProto* BertLoss::GetGatheredPredictionTypeProto(const NodeArg* prediction_ar
return type_proto;
}
// Builds the FLOAT tensor type proto describing `prediction_arg` after its last
// dimension has been moved to position 1, i.e. {d0, d1, ..., d(n-1)} becomes
// {d0, d(n-1), d1, ..., d(n-2)}. This matches the {N, C, D1, ..., Dk} layout of
// the scores consumed by SoftmaxCrossEntropyLoss.
//
// prediction_arg: node arg whose tensor shape is read; must be non-null, carry
//                 type information, and have rank >= 2.
// graph_defs:     used to create the returned TypeProto via CreateTypeProto().
// Returns the newly created type proto.
TypeProto* BertLoss::GetTransposedTypeProto(const NodeArg* prediction_arg,
                                            GraphAugmenter::GraphDefs& graph_defs) {
  ORT_ENFORCE(prediction_arg != nullptr, "GetTransposedTypeProto's prediction_arg is nullptr");
  const auto* logits_type_proto = prediction_arg->TypeAsProto();
  ORT_ENFORCE(logits_type_proto != nullptr, "GetTransposedTypeProto's prediction_arg has no type information");
  const auto& dims = logits_type_proto->tensor_type().shape().dim();

  // Both dims[0] and dims[rank - 1] are read below. With rank 0 that is an
  // out-of-range access, and with rank 1 the same dimension would be emitted
  // twice, producing a shape that does not correspond to any valid transpose.
  const int rank = static_cast<int>(dims.size());
  ORT_ENFORCE(rank >= 2, "GetTransposedTypeProto expects prediction_arg to have rank >= 2");

  TypeProto* type_proto = graph_defs.CreateTypeProto();
  type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);

  auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
  // Batch size.
  target_shape->add_dim()->CopyFrom(dims[0]);
  // Class.
  target_shape->add_dim()->CopyFrom(dims[rank - 1]);
  // Remaining dimensions keep their relative order.
  for (int i = 1; i < rank - 1; ++i) {
    target_shape->add_dim()->CopyFrom(dims[i]);
  }

  return type_proto;
}
// Builds the FLOAT tensor type proto for the gathered masked-LM prediction after
// it has been transposed for SoftmaxCrossEntropyLoss. The resulting shape is
// {dims[0], dims[2], "dynamic_prediction_count"}, where dims is the shape of
// `prediction_arg`.
//
// prediction_arg: node arg whose tensor shape is read; must be non-null, carry
//                 type information, and have rank >= 3 because dims[2] is used.
// graph_defs:     used to create the returned TypeProto via CreateTypeProto().
// Returns the newly created type proto.
TypeProto* BertLoss::GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg,
                                                              GraphAugmenter::GraphDefs& graph_defs) {
  ORT_ENFORCE(prediction_arg != nullptr, "GetGatheredPredictionTransposedTypeProto's prediction_arg is nullptr");
  const auto* logits_type_proto = prediction_arg->TypeAsProto();
  ORT_ENFORCE(logits_type_proto != nullptr,
              "GetGatheredPredictionTransposedTypeProto's prediction_arg has no type information");
  const auto& dims = logits_type_proto->tensor_type().shape().dim();

  // dims[0] and dims[2] are read below; reject shapes that are too short
  // instead of indexing past the end of the dimension list.
  ORT_ENFORCE(dims.size() >= 3,
              "GetGatheredPredictionTransposedTypeProto expects prediction_arg to have rank >= 3");

  TypeProto* type_proto = graph_defs.CreateTypeProto();
  type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);

  auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
  // Batch size.
  target_shape->add_dim()->CopyFrom(dims[0]);
  // Vocab size.
  target_shape->add_dim()->CopyFrom(dims[2]);
  // Prediction count.
  target_shape->add_dim()->set_dim_param("dynamic_prediction_count");

  return type_proto;
}
// Creates the type proto used for loss values: a FLOAT tensor type created with
// an empty dimension list.
TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
  const auto loss_elem_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
  TypeProto* loss_type_proto = graph_defs.CreateTypeProto({}, loss_elem_type);
  return loss_type_proto;
}
@ -54,15 +97,14 @@ TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) {
const std::string& total_loss = loss_func_info.loss_name;
const VectorString& args = loss_func_info.loss_builder_args;
ORT_ENFORCE(args.size() == 8, " Invalid loss_func_info for BertLoss.");
ORT_ENFORCE(args.size() == 7, " Invalid loss_func_info for BertLoss.");
const std::string& prediction_masked_lm = args[0];
const std::string& prediction_next_sentence = args[1];
const std::string& masked_lm_positions = args[2];
const std::string& masked_lm_ids = args[3];
const std::string& masked_lm_weights = args[4];
const std::string& next_sentence_labels = args[5];
const std::string& mlm_loss = args[6];
const std::string& nsp_loss = args[7];
const std::string& next_sentence_labels = args[4];
const std::string& mlm_loss = args[5];
const std::string& nsp_loss = args[6];
std::vector<NodeDef> new_nodes;
GraphAugmenter::GraphDefs graph_defs;
@ -89,16 +131,30 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
{ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast<int64_t>(1))},
"GATHERED_LM"));
TypeProto* masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
graph_defs);
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
{ArgDef("gathered_prediction", gathered_prediction_type_proto),
ArgDef(masked_lm_ids, masked_lm_int64_type_proto),
ArgDef(masked_lm_weights, masked_lm_float_type_proto)}, // Inputs
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
ArgDef("probability_lm", gathered_prediction_type_proto)},
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
// Transpose gathered_predictions with the following permutation: {0, 2, 1} because SoftmaxCrossEntropyLoss accepts
// scores of the shape {N, C, D1, D2...Dk}.
TypeProto* gathered_prediction_transposed_type_proto = GetGatheredPredictionTransposedTypeProto(prediction_arg,
graph_defs);
new_nodes.emplace_back(NodeDef("Transpose",
{ArgDef("gathered_prediction", gathered_prediction_type_proto)}, // Inputs
{ArgDef("gathered_prediction_transposed",
gathered_prediction_transposed_type_proto)}, // Outputs
{ONNX_NAMESPACE::MakeAttribute("perm", std::vector<int64_t>{static_cast<int64_t>(0),
static_cast<int64_t>(2), static_cast<int64_t>(1)})},
"Transpose_gathered_prediction"));
std::vector<AttributeProto> attrs;
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("ignore_index", static_cast<int64_t>(-1)));
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("reduction", "mean"));
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
{ArgDef("gathered_prediction_transposed", gathered_prediction_transposed_type_proto),
ArgDef(masked_lm_ids, masked_lm_int64_type_proto)}, // Inputs
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
ArgDef("probability_lm", gathered_prediction_transposed_type_proto)},
attrs,
"Masked_LM_Loss"));
}
@ -111,11 +167,33 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
ONNX_NAMESPACE::TensorProto_DataType_INT64,
graph_defs);
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
{ArgDef(prediction_next_sentence),
// Transpose prediction_next_sentence with the following permutation: {0, n-1, 1, 2....n-2} because
// SoftmaxCrossEntropyLoss accepts scores of the shape {N, C, D1, D2...Dk}.
TypeProto* prediction_next_sentence_transposed_type_proto = GetTransposedTypeProto(ns_prediction_arg, graph_defs);
const auto* logits_type_proto = ns_prediction_arg->TypeAsProto();
const auto& dims = logits_type_proto->tensor_type().shape().dim();
std::vector<int64_t> prediction_next_sentence_transposed_perm;
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(0));
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(dims.size() - 1));
for (int i = 1; i < dims.size() - 1; i++) {
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(i));
}
new_nodes.emplace_back(NodeDef("Transpose",
{ArgDef(prediction_next_sentence)}, // Inputs
{ArgDef("prediction_next_sentence_transposed",
prediction_next_sentence_transposed_type_proto)}, // Outputs
{ONNX_NAMESPACE::MakeAttribute("perm", prediction_next_sentence_transposed_perm)},
"Transpose_prediction_next_sentence"));
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
{ArgDef("prediction_next_sentence_transposed", prediction_next_sentence_transposed_type_proto),
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
{ArgDef(nsp_loss, GetLossTypeProto(graph_defs)),
ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs
ArgDef("probability_ns", prediction_next_sentence_transposed_type_proto)}, // Outputs
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
"Next_Sentence_Loss"));
}
@ -136,7 +214,7 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
}
graph_defs.AddNodeDefs(new_nodes);
graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels});
graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, next_sentence_labels});
graph_defs.AddGraphOutputs({mlm_loss, nsp_loss, total_loss});
return graph_defs;

View file

@ -15,8 +15,16 @@ struct BertLoss : public ILossFunction {
static TypeProto* GetMaskedLMTypeProto(const NodeArg* prediction_arg,
ONNX_NAMESPACE::TensorProto_DataType data_type,
GraphAugmenter::GraphDefs& graph_defs);
static TypeProto* GetGatheredPredictionTypeProto(const NodeArg* prediction_arg,
GraphAugmenter::GraphDefs& graph_defs);
static TypeProto* GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg,
GraphAugmenter::GraphDefs& graph_defs);
static TypeProto* GetTransposedTypeProto(const NodeArg* prediction_arg,
GraphAugmenter::GraphDefs& graph_defs);
static TypeProto* GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs);
};

View file

@ -566,7 +566,6 @@ void setup_training_params(BertParameters& params) {
/*prediction_next_sentence*/ "output2",
/*masked_lm_positions*/ "masked_lm_positions",
/*masked_lm_ids*/ "masked_lm_ids",
/*masked_lm_weights*/ "masked_lm_weights",
/*next_sentence_labels*/ "next_sentence_labels",
/*mlm_loss*/ "mlm_loss",
/*nsp_loss*/ "nsp_loss"});
@ -593,7 +592,6 @@ void setup_training_params(BertParameters& params) {
{"input_mask", "input3"},
{"masked_lm_positions", "masked_lm_positions"},
{"masked_lm_ids", "masked_lm_ids"},
{"masked_lm_weights", "masked_lm_weights"},
{"next_sentence_label", "next_sentence_labels"}};
params.model_type = "bert";
@ -696,7 +694,6 @@ static Status RunPerformanceTest(const BertParameters& params, const Environment
"input3", /*input_mask*/
"masked_lm_positions",
"masked_lm_ids",
"masked_lm_weights",
"next_sentence_labels"};
std::vector<TensorShape> tensor_shapes = {{batch_size, params.max_sequence_length},
{batch_size, params.max_sequence_length},

View file

@ -330,7 +330,6 @@ static void RunBertTrainingWithChecks(
"masked_lm_ids",
"next_sentence_labels",
"masked_lm_positions",
"masked_lm_weights",
};
std::vector<TensorShape> tensor_shapes = {
{batch_size, max_seq_len_in_batch},
@ -414,13 +413,11 @@ static void RunBertTrainingWithChecks(
0, 1, 2, 3, 4, 5, 6,
0, 1, 2, 3, 4, 5, 6,
0, 1, 2, 3, 4, 5, 6}};
std::vector<float> masked_lm_weights(13 * 7, 1.0f);
std::vector<OrtValue> feeds(feed_names.size());
for (size_t i = 0; i < 6; ++i) {
TrainingUtil::CreateCpuMLValue(tensor_shapes[i].GetDims(), tensor_values[i], &feeds[i]);
}
TrainingUtil::CreateCpuMLValue(tensor_shapes[6].GetDims(), masked_lm_weights, &feeds[6]);
auto output_names_include_gradients = GetModelOutputNames(*training_session);
std::vector<std::string> fetch_names(output_names_include_gradients.begin(), output_names_include_gradients.end());
@ -475,7 +472,6 @@ TEST(GradientGraphBuilderTest, TrainingSession_BertToy) {
/*prediction_next_sentence*/ "seq_relationship_score",
/*masked_lm_positions*/ "masked_lm_positions",
/*masked_lm_ids*/ "masked_lm_ids",
/*masked_lm_weights*/ "masked_lm_weights",
/*next_sentence_labels*/ "next_sentence_labels",
/*mlm_loss*/ "mlm_loss",
/*nsp_loss*/ "nsp_loss"});
@ -1163,7 +1159,6 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
/*prediction_next_sentence*/ "seq_relationship_score",
/*masked_lm_positions*/ "masked_lm_positions",
/*masked_lm_ids*/ "masked_lm_ids",
/*masked_lm_weights*/ "masked_lm_weights",
/*next_sentence_labels*/ "next_sentence_labels",
/*mlm_loss*/ "mlm_loss",
/*nsp_loss*/ "nsp_loss"});