diff --git a/onnxruntime/test/testdata/bert_toy_optimized.onnx b/onnxruntime/test/testdata/bert_toy_optimized.onnx index 0b99357470..f2d6b241bd 100644 Binary files a/onnxruntime/test/testdata/bert_toy_optimized.onnx and b/onnxruntime/test/testdata/bert_toy_optimized.onnx differ diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc index b3237c2405..02f6fc933a 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc @@ -47,6 +47,49 @@ TypeProto* BertLoss::GetGatheredPredictionTypeProto(const NodeArg* prediction_ar return type_proto; } +TypeProto* BertLoss::GetTransposedTypeProto(const NodeArg* prediction_arg, + GraphAugmenter::GraphDefs& graph_defs) { + ORT_ENFORCE(prediction_arg != nullptr, "GetTransposedTypeProto's prediction_arg is nullptr"); + const auto* logits_type_proto = prediction_arg->TypeAsProto(); + const auto& dims = logits_type_proto->tensor_type().shape().dim(); + + TypeProto* type_proto = graph_defs.CreateTypeProto(); + type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape(); + // Batch size. + target_shape->add_dim()->CopyFrom(dims[0]); + + // Class. + target_shape->add_dim()->CopyFrom(dims[dims.size() - 1]); + + for (int i = 1; i < dims.size() - 1; i++) { + target_shape->add_dim()->CopyFrom(dims[i]); + } + + return type_proto; +} + +TypeProto* BertLoss::GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg, + GraphAugmenter::GraphDefs& graph_defs) { + ORT_ENFORCE(prediction_arg != nullptr, "GetGatheredPredictionTransposedTypeProto's prediction_arg is nullptr"); + const auto* logits_type_proto = prediction_arg->TypeAsProto(); + const auto& dims = logits_type_proto->tensor_type().shape().dim(); + + TypeProto* type_proto = graph_defs.CreateTypeProto(); + type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape(); + // Batch size. + target_shape->add_dim()->CopyFrom(dims[0]); + // Vocab size. + target_shape->add_dim()->CopyFrom(dims[2]); + // Prediction count. + target_shape->add_dim()->set_dim_param("dynamic_prediction_count"); + + return type_proto; +} + TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) { return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT); } @@ -54,15 +97,14 @@ TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) { GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) { const std::string& total_loss = loss_func_info.loss_name; const VectorString& args = loss_func_info.loss_builder_args; - ORT_ENFORCE(args.size() == 8, " Invalid loss_func_info for BertLoss."); + ORT_ENFORCE(args.size() == 7, " Invalid loss_func_info for BertLoss."); const std::string& prediction_masked_lm = args[0]; const std::string& prediction_next_sentence = args[1]; const std::string& masked_lm_positions = args[2]; const std::string& masked_lm_ids = args[3]; - const std::string& masked_lm_weights = args[4]; - const std::string& next_sentence_labels = args[5]; - const std::string& mlm_loss = args[6]; - const std::string& nsp_loss = args[7]; + const std::string& next_sentence_labels = args[4]; + const std::string& mlm_loss = args[5]; + const std::string& nsp_loss = args[6]; std::vector new_nodes; GraphAugmenter::GraphDefs graph_defs; @@ -89,16 +131,30 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun {ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast(1))}, "GATHERED_LM")); - TypeProto* masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - graph_defs); - new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy", - {ArgDef("gathered_prediction", gathered_prediction_type_proto), - ArgDef(masked_lm_ids, masked_lm_int64_type_proto), - ArgDef(masked_lm_weights, masked_lm_float_type_proto)}, // Inputs - {ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs - ArgDef("probability_lm", gathered_prediction_type_proto)}, - {ONNX_NAMESPACE::MakeAttribute("reduction", "mean")}, + // Transpose gathered_predictions with the following permutation: {0, 2, 1} because SoftmaxCrossEntropyLoss accepts + // scores of the shape {N, C, D1, D2...Dk}. + TypeProto* gathered_prediction_transposed_type_proto = GetGatheredPredictionTransposedTypeProto(prediction_arg, + graph_defs); + + new_nodes.emplace_back(NodeDef("Transpose", + {ArgDef("gathered_prediction", gathered_prediction_type_proto)}, // Inputs + {ArgDef("gathered_prediction_transposed", + gathered_prediction_transposed_type_proto)}, // Outputs + {ONNX_NAMESPACE::MakeAttribute("perm", std::vector{static_cast(0), + static_cast(2), static_cast(1)})}, + + "Transpose_gathered_prediction")); + + std::vector attrs; + attrs.push_back(ONNX_NAMESPACE::MakeAttribute("ignore_index", static_cast(-1))); + attrs.push_back(ONNX_NAMESPACE::MakeAttribute("reduction", "mean")); + + new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", + {ArgDef("gathered_prediction_transposed", gathered_prediction_transposed_type_proto), + ArgDef(masked_lm_ids, masked_lm_int64_type_proto)}, // Inputs + {ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs + ArgDef("probability_lm", gathered_prediction_transposed_type_proto)}, + attrs, "Masked_LM_Loss")); } @@ -111,11 +167,33 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun ONNX_NAMESPACE::TensorProto_DataType_INT64, graph_defs); - new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy", - {ArgDef(prediction_next_sentence), + // Transpose prediction_next_sentence with the following permutation: {0, n-1, 1, 2....n-2} because + // SoftmaxCrossEntropyLoss accepts scores of the shape {N, C, D1, D2...Dk}. + + TypeProto* prediction_next_sentence_transposed_type_proto = GetTransposedTypeProto(ns_prediction_arg, graph_defs); + const auto* logits_type_proto = ns_prediction_arg->TypeAsProto(); + const auto& dims = logits_type_proto->tensor_type().shape().dim(); + std::vector prediction_next_sentence_transposed_perm; + prediction_next_sentence_transposed_perm.emplace_back(static_cast(0)); + prediction_next_sentence_transposed_perm.emplace_back(static_cast(dims.size() - 1)); + + for (int i = 1; i < dims.size() - 1; i++) { + prediction_next_sentence_transposed_perm.emplace_back(static_cast(i)); + } + + new_nodes.emplace_back(NodeDef("Transpose", + {ArgDef(prediction_next_sentence)}, // Inputs + {ArgDef("prediction_next_sentence_transposed", + prediction_next_sentence_transposed_type_proto)}, // Outputs + {ONNX_NAMESPACE::MakeAttribute("perm", prediction_next_sentence_transposed_perm)}, + + "Transpose_prediction_next_sentence")); + + new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", + {ArgDef("prediction_next_sentence_transposed", prediction_next_sentence_transposed_type_proto), ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs {ArgDef(nsp_loss, GetLossTypeProto(graph_defs)), - ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs + ArgDef("probability_ns", prediction_next_sentence_transposed_type_proto)}, // Outputs {ONNX_NAMESPACE::MakeAttribute("reduction", "mean")}, "Next_Sentence_Loss")); } @@ -136,7 +214,7 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun } graph_defs.AddNodeDefs(new_nodes); - graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels}); + graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, next_sentence_labels}); graph_defs.AddGraphOutputs({mlm_loss, nsp_loss, total_loss}); return graph_defs; diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss.h b/orttraining/orttraining/core/graph/loss_func/bert_loss.h index 787564baac..2a76a46b3a 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss.h +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss.h @@ -15,8 +15,16 @@ struct BertLoss : public ILossFunction { static TypeProto* GetMaskedLMTypeProto(const NodeArg* prediction_arg, ONNX_NAMESPACE::TensorProto_DataType data_type, GraphAugmenter::GraphDefs& graph_defs); + static TypeProto* GetGatheredPredictionTypeProto(const NodeArg* prediction_arg, GraphAugmenter::GraphDefs& graph_defs); + + static TypeProto* GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg, + GraphAugmenter::GraphDefs& graph_defs); + + static TypeProto* GetTransposedTypeProto(const NodeArg* prediction_arg, + GraphAugmenter::GraphDefs& graph_defs); + static TypeProto* GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs); }; diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index bbde1866b9..752211e640 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -566,7 +566,6 @@ void setup_training_params(BertParameters& params) { /*prediction_next_sentence*/ "output2", /*masked_lm_positions*/ "masked_lm_positions", /*masked_lm_ids*/ "masked_lm_ids", - /*masked_lm_weights*/ "masked_lm_weights", /*next_sentence_labels*/ "next_sentence_labels", /*mlm_loss*/ "mlm_loss", /*nsp_loss*/ "nsp_loss"}); @@ -593,7 +592,6 @@ void setup_training_params(BertParameters& params) { {"input_mask", "input3"}, {"masked_lm_positions", "masked_lm_positions"}, {"masked_lm_ids", "masked_lm_ids"}, - {"masked_lm_weights", "masked_lm_weights"}, {"next_sentence_label", "next_sentence_labels"}}; params.model_type = "bert"; @@ -696,7 +694,6 @@ static Status RunPerformanceTest(const BertParameters& params, const Environment "input3", /*input_mask*/ "masked_lm_positions", "masked_lm_ids", - "masked_lm_weights", "next_sentence_labels"}; std::vector tensor_shapes = {{batch_size, params.max_sequence_length}, {batch_size, params.max_sequence_length}, diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index d8810ab73a..571152b107 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -330,7 +330,6 @@ static void RunBertTrainingWithChecks( "masked_lm_ids", "next_sentence_labels", "masked_lm_positions", - "masked_lm_weights", }; std::vector tensor_shapes = { {batch_size, max_seq_len_in_batch}, @@ -414,13 +413,11 @@ static void RunBertTrainingWithChecks( 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6}}; - std::vector masked_lm_weights(13 * 7, 1.0f); std::vector feeds(feed_names.size()); for (size_t i = 0; i < 6; ++i) { TrainingUtil::CreateCpuMLValue(tensor_shapes[i].GetDims(), tensor_values[i], &feeds[i]); } - TrainingUtil::CreateCpuMLValue(tensor_shapes[6].GetDims(), masked_lm_weights, &feeds[6]); auto output_names_include_gradients = GetModelOutputNames(*training_session); std::vector fetch_names(output_names_include_gradients.begin(), output_names_include_gradients.end()); @@ -475,7 +472,6 @@ TEST(GradientGraphBuilderTest, TrainingSession_BertToy) { /*prediction_next_sentence*/ "seq_relationship_score", /*masked_lm_positions*/ "masked_lm_positions", /*masked_lm_ids*/ "masked_lm_ids", - /*masked_lm_weights*/ "masked_lm_weights", /*next_sentence_labels*/ "next_sentence_labels", /*mlm_loss*/ "mlm_loss", /*nsp_loss*/ "nsp_loss"}); @@ -1163,7 +1159,6 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) { /*prediction_next_sentence*/ "seq_relationship_score", /*masked_lm_positions*/ "masked_lm_positions", /*masked_lm_ids*/ "masked_lm_ids", - /*masked_lm_weights*/ "masked_lm_weights", /*next_sentence_labels*/ "next_sentence_labels", /*mlm_loss*/ "mlm_loss", /*nsp_loss*/ "nsp_loss"});