From d9ecc0cebf8752893d4cd4e547341e390dc29b01 Mon Sep 17 00:00:00 2001
From: "Tang, Cheng"
Date: Sun, 27 Sep 2020 13:41:16 -0700
Subject: [PATCH] add bert loss legacy back (#5224)

---
 .../core/graph/loss_func/bert_loss_legacy.cc  | 181 ++++++++++++++++++
 .../core/graph/loss_func/bert_loss_legacy.h   |  32 +++
 .../core/graph/loss_function_registry.cc      |   2 +
 3 files changed, 215 insertions(+)
 create mode 100644 orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc
 create mode 100644 orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.h

diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc
new file mode 100644
index 0000000000..90f124b301
--- /dev/null
+++ b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc
@@ -0,0 +1,181 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <string>
+#include <vector>
+#include "orttraining/core/graph/loss_func/bert_loss_legacy.h"
+#include "onnx/defs/attr_proto_util.h"
+
+namespace onnxruntime {
+namespace training {
+
+// Builds the type for a per-prediction MaskedLM tensor: element type `data_type` and shape
+// [batch, "dynamic_prediction_count"], with the batch dimension copied from dimension 0 of
+// `prediction_arg`. The returned TypeProto is created through `graph_defs`.
+// Fails via ORT_ENFORCE if `prediction_arg` is null, carries no type, or has no dimensions.
+TypeProto* BertLossLegacy::GetMaskedLMTypeProto(const NodeArg* prediction_arg,
+                                                ONNX_NAMESPACE::TensorProto_DataType data_type,
+                                                GraphAugmenter::GraphDefs& graph_defs) {
+  ORT_ENFORCE(prediction_arg != nullptr, "GetMaskedLMTypeProto's prediction_arg is nullptr");
+  const auto* logits_type_proto = prediction_arg->TypeAsProto();
+  ORT_ENFORCE(logits_type_proto != nullptr,
+              "GetMaskedLMTypeProto's prediction_arg has no type information");
+  const auto& dims = logits_type_proto->tensor_type().shape().dim();
+  // dims[0] is read below; reject shapes without dimensions instead of indexing out of range.
+  ORT_ENFORCE(dims.size() >= 1,
+              "GetMaskedLMTypeProto expects a prediction of rank >= 1, got rank ", dims.size());
+
+  TypeProto* type_proto = graph_defs.CreateTypeProto();
+  type_proto->mutable_tensor_type()->set_elem_type(data_type);
+
+  auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
+  // Batch size.
+  target_shape->add_dim()->CopyFrom(dims[0]);
+  // Prediction count.
+  target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
+
+  return type_proto;
+}
+
+// Builds the FLOAT type of the gathered MaskedLM predictions: shape
+// [batch, "dynamic_prediction_count", vocab], with batch and vocab copied from dimensions 0
+// and 2 of `prediction_arg`. The returned TypeProto is created through `graph_defs`.
+// Fails via ORT_ENFORCE if `prediction_arg` is null, carries no type, or has fewer than 3 dimensions.
+TypeProto* BertLossLegacy::GetGatheredPredictionTypeProto(const NodeArg* prediction_arg,
+                                                          GraphAugmenter::GraphDefs& graph_defs) {
+  ORT_ENFORCE(prediction_arg != nullptr,
+              "GetGatheredPredictionTypeProto's prediction_arg is nullptr");
+  const auto* logits_type_proto = prediction_arg->TypeAsProto();
+  ORT_ENFORCE(logits_type_proto != nullptr,
+              "GetGatheredPredictionTypeProto's prediction_arg has no type information");
+  const auto& dims = logits_type_proto->tensor_type().shape().dim();
+  // dims[0] and dims[2] are read below, so at least three dimensions are required.
+  ORT_ENFORCE(dims.size() >= 3,
+              "GetGatheredPredictionTypeProto expects a prediction of rank >= 3, got rank ", dims.size());
+
+  TypeProto* type_proto = graph_defs.CreateTypeProto();
+  type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+
+  auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
+  // Batch size.
+  target_shape->add_dim()->CopyFrom(dims[0]);
+  // Prediction count.
+  target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
+  // Vocab size.
+  target_shape->add_dim()->CopyFrom(dims[2]);
+
+  return type_proto;
+}
+
+// Returns the type used for every loss output: FLOAT, created with an empty dimension list.
+TypeProto* BertLossLegacy::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
+  return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+}
+
+// Describes the nodes that compute the legacy BERT pre-training loss:
+//   total_loss = mlm_loss + nsp_loss
+// mlm_loss is SparseSoftmaxCrossEntropy over the predictions gathered at masked_lm_positions,
+// nsp_loss is SparseSoftmaxCrossEntropy over the next-sentence predictions; both use
+// reduction "mean".
+//
+// loss_func_info.loss_builder_args must hold exactly these 8 names, in order:
+//   [0] prediction_masked_lm   [1] prediction_next_sentence
+//   [2] masked_lm_positions    [3] masked_lm_ids
+//   [4] masked_lm_weights      [5] next_sentence_labels
+//   [6] mlm_loss               [7] nsp_loss
+// Both prediction args must already exist in `graph`. Args [2]-[5] are added as graph inputs;
+// mlm_loss, nsp_loss and loss_func_info.loss_name (the total loss) are added as graph outputs.
+GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) {
+  const std::string& total_loss = loss_func_info.loss_name;
+  const VectorString& args = loss_func_info.loss_builder_args;
+  ORT_ENFORCE(args.size() == 8,
+              "Invalid loss_func_info for BertLossLegacy: expected 8 loss_builder_args, got ", args.size());
+  const std::string& prediction_masked_lm = args[0];
+  const std::string& prediction_next_sentence = args[1];
+  const std::string& masked_lm_positions = args[2];
+  const std::string& masked_lm_ids = args[3];
+  const std::string& masked_lm_weights = args[4];
+  const std::string& next_sentence_labels = args[5];
+  const std::string& mlm_loss = args[6];
+  const std::string& nsp_loss = args[7];
+
+  std::vector<NodeDef> new_nodes;
+  GraphAugmenter::GraphDefs graph_defs;
+  // SparseSoftmaxCrossEntropy for masked_lm
+  {
+    const NodeArg* prediction_arg = graph.GetNodeArg(prediction_masked_lm);
+    ORT_ENFORCE(prediction_arg != nullptr,
+                "Masked_LM prediction arg ", prediction_masked_lm, " is not found in the graph.");
+    TypeProto* masked_lm_int64_type_proto = GetMaskedLMTypeProto(prediction_arg,
+                                                                 ONNX_NAMESPACE::TensorProto_DataType_INT64,
+                                                                 graph_defs);
+
+    // Give the positions a trailing axis so GatherND (batch_dims = 1) picks one prediction row per position.
+    new_nodes.emplace_back(NodeDef("Unsqueeze",
+                                   {ArgDef(masked_lm_positions, masked_lm_int64_type_proto)},
+                                   {ArgDef("masked_lm_positions_unsqueezed")},
+                                   {ONNX_NAMESPACE::MakeAttribute("axes", std::vector<int64_t>{static_cast<int64_t>(2)})},
+                                   "Mask_LM_Positions_Unsqueezed"));
+    TypeProto* gathered_prediction_type_proto = GetGatheredPredictionTypeProto(prediction_arg,
+                                                                               graph_defs);
+    new_nodes.emplace_back(NodeDef(OpDef{"GatherND", kOnnxDomain, 12},
+                                   {ArgDef(prediction_masked_lm), ArgDef("masked_lm_positions_unsqueezed")},
+                                   {ArgDef("gathered_prediction", gathered_prediction_type_proto)},
+                                   {ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast<int64_t>(1))},
+                                   "GATHERED_LM"));
+
+    TypeProto* masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg,
+                                                                 ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
+                                                                 graph_defs);
+    new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
+                                   {ArgDef("gathered_prediction", gathered_prediction_type_proto),
+                                    ArgDef(masked_lm_ids, masked_lm_int64_type_proto),
+                                    ArgDef(masked_lm_weights, masked_lm_float_type_proto)},  // Inputs
+                                   {ArgDef(mlm_loss, GetLossTypeProto(graph_defs)),  // Outputs
+                                    ArgDef("probability_lm", gathered_prediction_type_proto)},
+                                   {ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
+                                   "Masked_LM_Loss"));
+  }
+
+  // SparseSoftmaxCrossEntropy for next_sentence
+  {
+    const NodeArg* ns_prediction_arg = graph.GetNodeArg(prediction_next_sentence);
+    ORT_ENFORCE(ns_prediction_arg != nullptr,
+                "Next sentence prediction arg ", prediction_next_sentence, " is not found in the graph.");
+    TypeProto* next_sentence_labels_type_proto = GetSparseTypeProto(ns_prediction_arg,
+                                                                    ONNX_NAMESPACE::TensorProto_DataType_INT64,
+                                                                    graph_defs);
+
+    new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
+                                   {ArgDef(prediction_next_sentence),
+                                    ArgDef(next_sentence_labels, next_sentence_labels_type_proto)},  // Inputs
+                                   {ArgDef(nsp_loss, GetLossTypeProto(graph_defs)),
+                                    ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())},  // Outputs
+                                   {ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
+                                   "Next_Sentence_Loss"));
+  }
+
+  // Add
+  {
+    new_nodes.emplace_back(NodeDef("Add",  // Op
+                                   {
+                                       ArgDef(mlm_loss),
+                                       ArgDef(nsp_loss)  // Inputs
+                                   },
+                                   {
+                                       ArgDef(total_loss, GetLossTypeProto(graph_defs))  // Outputs
+                                   },
+                                   NodeAttributes(),
+                                   "Bert_Total_Loss"  // name
+                                   ));
+  }
+
+  graph_defs.AddNodeDefs(new_nodes);
+  graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels});
+  graph_defs.AddGraphOutputs({mlm_loss, nsp_loss, total_loss});
+
+  return graph_defs;
+}
+
+}  // namespace training
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.h b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.h
new file mode 100644
index 0000000000..0a40d79327
--- /dev/null
+++ b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include "orttraining/core/graph/loss_func/loss_func_common.h"
+
+namespace onnxruntime {
+namespace training {
+
+// Loss-function builder that assembles the BERT pre-training loss (masked-LM loss plus
+// next-sentence loss) from Unsqueeze, GatherND, SparseSoftmaxCrossEntropy and Add nodes.
+struct BertLossLegacy : public ILossFunction {
+  // Returns the node, graph-input and graph-output definitions for the loss.
+  // loss_func_info.loss_builder_args must contain 8 names; see the definition for their order.
+  GraphAugmenter::GraphDefs operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) override;
+
+ private:
+  // Type with element type `data_type` and shape [batch, "dynamic_prediction_count"].
+  static TypeProto* GetMaskedLMTypeProto(const NodeArg* prediction_arg,
+                                         ONNX_NAMESPACE::TensorProto_DataType data_type,
+                                         GraphAugmenter::GraphDefs& graph_defs);
+  // FLOAT type with shape [batch, "dynamic_prediction_count", vocab].
+  static TypeProto* GetGatheredPredictionTypeProto(const NodeArg* prediction_arg,
+                                                   GraphAugmenter::GraphDefs& graph_defs);
+  // FLOAT type shared by all loss outputs.
+  static TypeProto* GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs);
+};
+
+}  // namespace training
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/graph/loss_function_registry.cc b/orttraining/orttraining/core/graph/loss_function_registry.cc
index e95ea9e3b6..ddc1eb2f3d 100644
--- a/orttraining/orttraining/core/graph/loss_function_registry.cc
+++ b/orttraining/orttraining/core/graph/loss_function_registry.cc
@@ -5,6 +5,7 @@
 #include "loss_function_builder.h"
 #include "loss_func/mean_squared_error.h"
 #include "loss_func/bert_loss.h"
+#include "loss_func/bert_loss_legacy.h"
 #include "loss_func/softmax_cross_entropy.h"
 
 namespace onnxruntime {
@@ -56,6 +57,7 @@ void LossFunctionRegistry::RegisterNonOperatorLossFunctions() {
   // Register non-operator loss functions here.
   REGISTER_NON_OPERATOR_LOSS_FUNCTION(MeanSquaredError);
   REGISTER_NON_OPERATOR_LOSS_FUNCTION(BertLoss);
+  REGISTER_NON_OPERATOR_LOSS_FUNCTION(BertLossLegacy);
   REGISTER_NON_OPERATOR_LOSS_FUNCTION(SoftmaxCrossEntropy);
   REGISTER_NON_OPERATOR_LOSS_FUNCTION(SparseSoftmaxCrossEntropy);
   REGISTER_NON_OPERATOR_LOSS_FUNCTION(SoftmaxCrossEntropyLoss);