diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc index 02f6fc933a..3b8b60ab56 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss.cc +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss.cc @@ -47,49 +47,6 @@ TypeProto* BertLoss::GetGatheredPredictionTypeProto(const NodeArg* prediction_ar return type_proto; } -TypeProto* BertLoss::GetTransposedTypeProto(const NodeArg* prediction_arg, - GraphAugmenter::GraphDefs& graph_defs) { - ORT_ENFORCE(prediction_arg != nullptr, "GetTransposedTypeProto's prediction_arg is nullptr"); - const auto* logits_type_proto = prediction_arg->TypeAsProto(); - const auto& dims = logits_type_proto->tensor_type().shape().dim(); - - TypeProto* type_proto = graph_defs.CreateTypeProto(); - type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - - auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape(); - // Batch size. - target_shape->add_dim()->CopyFrom(dims[0]); - - // Class. - target_shape->add_dim()->CopyFrom(dims[dims.size() - 1]); - - for (int i = 1; i < dims.size() - 1; i++) { - target_shape->add_dim()->CopyFrom(dims[i]); - } - - return type_proto; -} - -TypeProto* BertLoss::GetGatheredPredictionTransposedTypeProto(const NodeArg* prediction_arg, - GraphAugmenter::GraphDefs& graph_defs) { - ORT_ENFORCE(prediction_arg != nullptr, "GetGatheredPredictionTransposedTypeProto's prediction_arg is nullptr"); - const auto* logits_type_proto = prediction_arg->TypeAsProto(); - const auto& dims = logits_type_proto->tensor_type().shape().dim(); - - TypeProto* type_proto = graph_defs.CreateTypeProto(); - type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - - auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape(); - // Batch size. - target_shape->add_dim()->CopyFrom(dims[0]); - // Vocab size. - target_shape->add_dim()->CopyFrom(dims[2]); - // Prediction count. - target_shape->add_dim()->set_dim_param("dynamic_prediction_count"); - - return type_proto; -} - TypeProto* BertLoss::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) { return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT); } @@ -108,7 +65,6 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun std::vector new_nodes; GraphAugmenter::GraphDefs graph_defs; - // LabelSoftmaxCrossEntropy for masked_lm { const NodeArg* prediction_arg = graph.GetNodeArg(prediction_masked_lm); @@ -130,30 +86,50 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun {ArgDef("gathered_prediction", gathered_prediction_type_proto)}, {ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast(1))}, "GATHERED_LM")); + + ONNX_NAMESPACE::TensorProto t_proto; + t_proto.add_dims(2); + t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + t_proto.add_int64_data(static_cast(-1)); + t_proto.add_int64_data(prediction_arg->TypeAsProto()->tensor_type().shape().dim()[2].dim_value()); + new_nodes.emplace_back(NodeDef("Constant", + {}, + {ArgDef("logit_reshape", nullptr)}, + {ONNX_NAMESPACE::MakeAttribute("value", t_proto)})); - // Transpose gathered_predictions with the following permutation: {0, 2, 1} because SoftmaxCrossEntropyLoss accepts - // scores of the shape {N, C, D1, D2...Dk}. - TypeProto* gathered_prediction_transposed_type_proto = GetGatheredPredictionTransposedTypeProto(prediction_arg, - graph_defs); + new_nodes.emplace_back(NodeDef("Reshape", + {ArgDef("gathered_prediction", gathered_prediction_type_proto), + ArgDef("logit_reshape")}, // Inputs + {ArgDef("gathered_prediction_reshaped")}, // Outputs + NodeAttributes(), + "Reshape_gathered_prediction")); - new_nodes.emplace_back(NodeDef("Transpose", - {ArgDef("gathered_prediction", gathered_prediction_type_proto)}, // Inputs - {ArgDef("gathered_prediction_transposed", - gathered_prediction_transposed_type_proto)}, // Outputs - {ONNX_NAMESPACE::MakeAttribute("perm", std::vector{static_cast(0), - static_cast(2), static_cast(1)})}, + ONNX_NAMESPACE::TensorProto t_proto_label; + t_proto_label.add_dims(1); + t_proto_label.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + t_proto_label.add_int64_data(static_cast(-1)); - "Transpose_gathered_prediction")); + new_nodes.emplace_back(NodeDef("Constant", + {}, + {ArgDef("label_reshape", nullptr)}, + {ONNX_NAMESPACE::MakeAttribute("value", t_proto_label)})); + + new_nodes.emplace_back(NodeDef("Reshape", + {ArgDef("masked_lm_ids", masked_lm_int64_type_proto), + ArgDef("label_reshape")}, // Inputs + {ArgDef("masked_lm_ids_reshaped")}, // Outputs + NodeAttributes(), + "Reshape_label")); std::vector attrs; attrs.push_back(ONNX_NAMESPACE::MakeAttribute("ignore_index", static_cast(-1))); attrs.push_back(ONNX_NAMESPACE::MakeAttribute("reduction", "mean")); new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", - {ArgDef("gathered_prediction_transposed", gathered_prediction_transposed_type_proto), - ArgDef(masked_lm_ids, masked_lm_int64_type_proto)}, // Inputs - {ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs - ArgDef("probability_lm", gathered_prediction_transposed_type_proto)}, + {ArgDef("gathered_prediction_reshaped"), + ArgDef("masked_lm_ids_reshaped")}, // Inputs + {ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs + ArgDef("probability_lm")}, attrs, "Masked_LM_Loss")); } @@ -167,33 +143,11 @@ GraphAugmenter::GraphDefs BertLoss::operator()(const Graph& graph, const LossFun ONNX_NAMESPACE::TensorProto_DataType_INT64, graph_defs); - // Transpose prediction_next_sentence with the following permutation: {0, n-1, 1, 2....n-2} because - // SoftmaxCrossEntropyLoss accepts scores of the shape {N, C, D1, D2...Dk}. - - TypeProto* prediction_next_sentence_transposed_type_proto = GetTransposedTypeProto(ns_prediction_arg, graph_defs); - const auto* logits_type_proto = ns_prediction_arg->TypeAsProto(); - const auto& dims = logits_type_proto->tensor_type().shape().dim(); - std::vector prediction_next_sentence_transposed_perm; - prediction_next_sentence_transposed_perm.emplace_back(static_cast(0)); - prediction_next_sentence_transposed_perm.emplace_back(static_cast(dims.size() - 1)); - - for (int i = 1; i < dims.size() - 1; i++) { - prediction_next_sentence_transposed_perm.emplace_back(static_cast(i)); - } - - new_nodes.emplace_back(NodeDef("Transpose", - {ArgDef(prediction_next_sentence)}, // Inputs - {ArgDef("prediction_next_sentence_transposed", - prediction_next_sentence_transposed_type_proto)}, // Outputs - {ONNX_NAMESPACE::MakeAttribute("perm", prediction_next_sentence_transposed_perm)}, - - "Transpose_prediction_next_sentence")); - new_nodes.emplace_back(NodeDef("SoftmaxCrossEntropyLoss", - {ArgDef("prediction_next_sentence_transposed", prediction_next_sentence_transposed_type_proto), - ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs + {ArgDef(prediction_next_sentence), + ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs {ArgDef(nsp_loss, GetLossTypeProto(graph_defs)), - ArgDef("probability_ns", prediction_next_sentence_transposed_type_proto)}, // Outputs + ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs {ONNX_NAMESPACE::MakeAttribute("reduction", "mean")}, "Next_Sentence_Loss")); }