mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Replace loss function in BERT_LOSS with SoftmaxCrossEntropyLoss. (#4509)
* Replace loss function in BERT_LOSS with SoftmaxCrossEntropyLoss. * Update BERT loss function with correct logit shapes for softmax cross entropy loss. * fix test and PR comments.
This commit is contained in:
parent
76b31d6ce2
commit
b43ce2d7ad
5 changed files with 105 additions and 27 deletions
BIN
onnxruntime/test/testdata/bert_toy_optimized.onnx
vendored
BIN
onnxruntime/test/testdata/bert_toy_optimized.onnx
vendored
Binary file not shown.
|
|
@ -47,6 +47,49 @@ TypeProto* BertLoss::GetGatheredPredictionTypeProto(const NodeArg* prediction_ar
|
|||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLoss::GetTransposedTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs) {
|
||||
ORT_ENFORCE(prediction_arg != nullptr, "GetTransposedTypeProto's prediction_arg is nullptr");
|
||||
const auto* logits_type_proto = prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
|
||||
TypeProto* type_proto = graph_defs.CreateTypeProto();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
|
||||
// Batch size.
|
||||
target_shape->add_dim()->CopyFrom(dims[0]);
|
||||
|
||||
// Class.
|
||||
target_shape->add_dim()->CopyFrom(dims[dims.size() - 1]);
|
||||
|
||||
for (int i = 1; i < dims.size() - 1; i++) {
|
||||
target_shape->add_dim()->CopyFrom(dims[i]);
|
||||
}
|
||||
|
||||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLoss::GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs) {
|
||||
ORT_ENFORCE(prediction_arg != nullptr, "GetGatheredPredictionTransposedTypeProto's prediction_arg is nullptr");
|
||||
const auto* logits_type_proto = prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
|
||||
TypeProto* type_proto = graph_defs.CreateTypeProto();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
|
||||
// Batch size.
|
||||
target_shape->add_dim()->CopyFrom(dims[0]);
|
||||
// Vocab size.
|
||||
target_shape->add_dim()->CopyFrom(dims[2]);
|
||||
// Prediction count.
|
||||
target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
|
||||
|
||||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
|
||||
return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
}
|
||||
|
|
@ -54,15 +97,14 @@ TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
|
|||
GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) {
|
||||
const std::string& total_loss = loss_func_info.loss_name;
|
||||
const VectorString& args = loss_func_info.loss_builder_args;
|
||||
ORT_ENFORCE(args.size() == 8, " Invalid loss_func_info for BertLoss.");
|
||||
ORT_ENFORCE(args.size() == 7, " Invalid loss_func_info for BertLoss.");
|
||||
const std::string& prediction_masked_lm = args[0];
|
||||
const std::string& prediction_next_sentence = args[1];
|
||||
const std::string& masked_lm_positions = args[2];
|
||||
const std::string& masked_lm_ids = args[3];
|
||||
const std::string& masked_lm_weights = args[4];
|
||||
const std::string& next_sentence_labels = args[5];
|
||||
const std::string& mlm_loss = args[6];
|
||||
const std::string& nsp_loss = args[7];
|
||||
const std::string& next_sentence_labels = args[4];
|
||||
const std::string& mlm_loss = args[5];
|
||||
const std::string& nsp_loss = args[6];
|
||||
|
||||
std::vector<NodeDef> new_nodes;
|
||||
GraphAugmenter::GraphDefs graph_defs;
|
||||
|
|
@ -89,16 +131,30 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
|
|||
{ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast<int64_t>(1))},
|
||||
"GATHERED_LM"));
|
||||
|
||||
TypeProto* masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
graph_defs);
|
||||
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
|
||||
{ArgDef("gathered_prediction", gathered_prediction_type_proto),
|
||||
ArgDef(masked_lm_ids, masked_lm_int64_type_proto),
|
||||
ArgDef(masked_lm_weights, masked_lm_float_type_proto)}, // Inputs
|
||||
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
|
||||
ArgDef("probability_lm", gathered_prediction_type_proto)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
|
||||
// Transpose gathered_predictions with the following permutation: {0, 2, 1} because SoftmaxCrossEntropyLoss accepts
|
||||
// scores of the shape {N, C, D1, D2...Dk}.
|
||||
TypeProto* gathered_prediction_transposed_type_proto = GetGatheredPredictionTransposedTypeProto(prediction_arg,
|
||||
graph_defs);
|
||||
|
||||
new_nodes.emplace_back(NodeDef("Transpose",
|
||||
{ArgDef("gathered_prediction", gathered_prediction_type_proto)}, // Inputs
|
||||
{ArgDef("gathered_prediction_transposed",
|
||||
gathered_prediction_transposed_type_proto)}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("perm", std::vector<int64_t>{static_cast<int64_t>(0),
|
||||
static_cast<int64_t>(2), static_cast<int64_t>(1)})},
|
||||
|
||||
"Transpose_gathered_prediction"));
|
||||
|
||||
std::vector<AttributeProto> attrs;
|
||||
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("ignore_index", static_cast<int64_t>(-1)));
|
||||
attrs.push_back(ONNX_NAMESPACE::MakeAttribute("reduction", "mean"));
|
||||
|
||||
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
|
||||
{ArgDef("gathered_prediction_transposed", gathered_prediction_transposed_type_proto),
|
||||
ArgDef(masked_lm_ids, masked_lm_int64_type_proto)}, // Inputs
|
||||
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
|
||||
ArgDef("probability_lm", gathered_prediction_transposed_type_proto)},
|
||||
attrs,
|
||||
"Masked_LM_Loss"));
|
||||
}
|
||||
|
||||
|
|
@ -111,11 +167,33 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
|
|||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
graph_defs);
|
||||
|
||||
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
|
||||
{ArgDef(prediction_next_sentence),
|
||||
// Transpose prediction_next_sentence with the following permutation: {0, n-1, 1, 2....n-2} because
|
||||
// SoftmaxCrossEntropyLoss accepts scores of the shape {N, C, D1, D2...Dk}.
|
||||
|
||||
TypeProto* prediction_next_sentence_transposed_type_proto = GetTransposedTypeProto(ns_prediction_arg, graph_defs);
|
||||
const auto* logits_type_proto = ns_prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
std::vector<int64_t> prediction_next_sentence_transposed_perm;
|
||||
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(0));
|
||||
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(dims.size() - 1));
|
||||
|
||||
for (int i = 1; i < dims.size() - 1; i++) {
|
||||
prediction_next_sentence_transposed_perm.emplace_back(static_cast<int64_t>(i));
|
||||
}
|
||||
|
||||
new_nodes.emplace_back(NodeDef("Transpose",
|
||||
{ArgDef(prediction_next_sentence)}, // Inputs
|
||||
{ArgDef("prediction_next_sentence_transposed",
|
||||
prediction_next_sentence_transposed_type_proto)}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("perm", prediction_next_sentence_transposed_perm)},
|
||||
|
||||
"Transpose_prediction_next_sentence"));
|
||||
|
||||
new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss",
|
||||
{ArgDef("prediction_next_sentence_transposed", prediction_next_sentence_transposed_type_proto),
|
||||
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
|
||||
{ArgDef(nsp_loss, GetLossTypeProto(graph_defs)),
|
||||
ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs
|
||||
ArgDef("probability_ns", prediction_next_sentence_transposed_type_proto)}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
|
||||
"Next_Sentence_Loss"));
|
||||
}
|
||||
|
|
@ -136,7 +214,7 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun
|
|||
}
|
||||
|
||||
graph_defs.AddNodeDefs(new_nodes);
|
||||
graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels});
|
||||
graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, next_sentence_labels});
|
||||
graph_defs.AddGraphOutputs({mlm_loss, nsp_loss, total_loss});
|
||||
|
||||
return graph_defs;
|
||||
|
|
|
|||
|
|
@ -15,8 +15,16 @@ struct BertLoss : public ILossFunction {
|
|||
static TypeProto* GetMaskedLMTypeProto(const NodeArg* prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType data_type,
|
||||
GraphAugmenter::GraphDefs& graph_defs);
|
||||
|
||||
static TypeProto* GetGatheredPredictionTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs);
|
||||
|
||||
static TypeProto* GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs);
|
||||
|
||||
static TypeProto* GetTransposedTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs);
|
||||
|
||||
static TypeProto* GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs);
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -566,7 +566,6 @@ void setup_training_params(BertParameters& params) {
|
|||
/*prediction_next_sentence*/ "output2",
|
||||
/*masked_lm_positions*/ "masked_lm_positions",
|
||||
/*masked_lm_ids*/ "masked_lm_ids",
|
||||
/*masked_lm_weights*/ "masked_lm_weights",
|
||||
/*next_sentence_labels*/ "next_sentence_labels",
|
||||
/*mlm_loss*/ "mlm_loss",
|
||||
/*nsp_loss*/ "nsp_loss"});
|
||||
|
|
@ -593,7 +592,6 @@ void setup_training_params(BertParameters& params) {
|
|||
{"input_mask", "input3"},
|
||||
{"masked_lm_positions", "masked_lm_positions"},
|
||||
{"masked_lm_ids", "masked_lm_ids"},
|
||||
{"masked_lm_weights", "masked_lm_weights"},
|
||||
{"next_sentence_label", "next_sentence_labels"}};
|
||||
|
||||
params.model_type = "bert";
|
||||
|
|
@ -696,7 +694,6 @@ static Status RunPerformanceTest(const BertParameters& params, const Environment
|
|||
"input3", /*input_mask*/
|
||||
"masked_lm_positions",
|
||||
"masked_lm_ids",
|
||||
"masked_lm_weights",
|
||||
"next_sentence_labels"};
|
||||
std::vector<TensorShape> tensor_shapes = {{batch_size, params.max_sequence_length},
|
||||
{batch_size, params.max_sequence_length},
|
||||
|
|
|
|||
|
|
@ -330,7 +330,6 @@ static void RunBertTrainingWithChecks(
|
|||
"masked_lm_ids",
|
||||
"next_sentence_labels",
|
||||
"masked_lm_positions",
|
||||
"masked_lm_weights",
|
||||
};
|
||||
std::vector<TensorShape> tensor_shapes = {
|
||||
{batch_size, max_seq_len_in_batch},
|
||||
|
|
@ -414,13 +413,11 @@ static void RunBertTrainingWithChecks(
|
|||
0, 1, 2, 3, 4, 5, 6,
|
||||
0, 1, 2, 3, 4, 5, 6,
|
||||
0, 1, 2, 3, 4, 5, 6}};
|
||||
std::vector<float> masked_lm_weights(13 * 7, 1.0f);
|
||||
|
||||
std::vector<OrtValue> feeds(feed_names.size());
|
||||
for (size_t i = 0; i < 6; ++i) {
|
||||
TrainingUtil::CreateCpuMLValue(tensor_shapes[i].GetDims(), tensor_values[i], &feeds[i]);
|
||||
}
|
||||
TrainingUtil::CreateCpuMLValue(tensor_shapes[6].GetDims(), masked_lm_weights, &feeds[6]);
|
||||
|
||||
auto output_names_include_gradients = GetModelOutputNames(*training_session);
|
||||
std::vector<std::string> fetch_names(output_names_include_gradients.begin(), output_names_include_gradients.end());
|
||||
|
|
@ -475,7 +472,6 @@ TEST(GradientGraphBuilderTest, TrainingSession_BertToy) {
|
|||
/*prediction_next_sentence*/ "seq_relationship_score",
|
||||
/*masked_lm_positions*/ "masked_lm_positions",
|
||||
/*masked_lm_ids*/ "masked_lm_ids",
|
||||
/*masked_lm_weights*/ "masked_lm_weights",
|
||||
/*next_sentence_labels*/ "next_sentence_labels",
|
||||
/*mlm_loss*/ "mlm_loss",
|
||||
/*nsp_loss*/ "nsp_loss"});
|
||||
|
|
@ -1163,7 +1159,6 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
|
|||
/*prediction_next_sentence*/ "seq_relationship_score",
|
||||
/*masked_lm_positions*/ "masked_lm_positions",
|
||||
/*masked_lm_ids*/ "masked_lm_ids",
|
||||
/*masked_lm_weights*/ "masked_lm_weights",
|
||||
/*next_sentence_labels*/ "next_sentence_labels",
|
||||
/*mlm_loss*/ "mlm_loss",
|
||||
/*nsp_loss*/ "nsp_loss"});
|
||||
|
|
|
|||
Loading…
Reference in a new issue