mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
add bert loss legacy back (#5224)
This commit is contained in:
parent
16d35266ab
commit
d9ecc0cebf
3 changed files with 171 additions and 0 deletions
145
orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc
Normal file
145
orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <vector>
|
||||
#include "orttraining/core/graph/loss_func/bert_loss_legacy.h"
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
TypeProto* BertLossLegacy::GetMaskedLMTypeProto(const NodeArg* prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType data_type,
|
||||
GraphAugmenter::GraphDefs& graph_defs) {
|
||||
ORT_ENFORCE(prediction_arg != nullptr, "GetMaskedPredictionTypeProto's prediction_arg is nullptr");
|
||||
const auto* logits_type_proto = prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
|
||||
TypeProto* type_proto = graph_defs.CreateTypeProto();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(data_type);
|
||||
|
||||
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
|
||||
// Batch size.
|
||||
target_shape->add_dim()->CopyFrom(dims[0]);
|
||||
// Prediction count.
|
||||
target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
|
||||
|
||||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLossLegacy::GetGatheredPredictionTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs) {
|
||||
ORT_ENFORCE(prediction_arg != nullptr, "GetMaskedPredictionTypeProto's prediction_arg is nullptr");
|
||||
const auto* logits_type_proto = prediction_arg->TypeAsProto();
|
||||
const auto& dims = logits_type_proto->tensor_type().shape().dim();
|
||||
|
||||
TypeProto* type_proto = graph_defs.CreateTypeProto();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
||||
auto* target_shape = type_proto->mutable_tensor_type()->mutable_shape();
|
||||
// Batch size.
|
||||
target_shape->add_dim()->CopyFrom(dims[0]);
|
||||
// Prediction count.
|
||||
target_shape->add_dim()->set_dim_param("dynamic_prediction_count");
|
||||
// Vocab size.
|
||||
target_shape->add_dim()->CopyFrom(dims[2]);
|
||||
|
||||
return type_proto;
|
||||
}
|
||||
|
||||
TypeProto* BertLossLegacy::GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs) {
|
||||
return graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
}
|
||||
|
||||
GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) {
|
||||
const std::string& total_loss = loss_func_info.loss_name;
|
||||
const VectorString& args = loss_func_info.loss_builder_args;
|
||||
ORT_ENFORCE(args.size() == 8, " Invalid loss_func_info for BertLoss.");
|
||||
const std::string& prediction_masked_lm = args[0];
|
||||
const std::string& prediction_next_sentence = args[1];
|
||||
const std::string& masked_lm_positions = args[2];
|
||||
const std::string& masked_lm_ids = args[3];
|
||||
const std::string& masked_lm_weights = args[4];
|
||||
const std::string& next_sentence_labels = args[5];
|
||||
const std::string& mlm_loss = args[6];
|
||||
const std::string& nsp_loss = args[7];
|
||||
|
||||
std::vector<NodeDef> new_nodes;
|
||||
GraphAugmenter::GraphDefs graph_defs;
|
||||
// LabelSoftmaxCrossEntropy for masked_lm
|
||||
{
|
||||
const NodeArg* prediction_arg = graph.GetNodeArg(prediction_masked_lm);
|
||||
ORT_ENFORCE(prediction_arg != nullptr,
|
||||
"Masked_ML prediction arg ", prediction_masked_lm, " is not found in the graph.");
|
||||
TypeProto* masked_lm_int64_type_proto = GetMaskedLMTypeProto(prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
graph_defs);
|
||||
|
||||
new_nodes.emplace_back(NodeDef("Unsqueeze",
|
||||
{ArgDef(masked_lm_positions, masked_lm_int64_type_proto)},
|
||||
{ArgDef("masked_lm_positions_unsqueezed")},
|
||||
{ONNX_NAMESPACE::MakeAttribute("axes", std::vector<int64_t>{static_cast<int64_t>(2)})},
|
||||
"Mask_LM_Positions_Unsqueezed"));
|
||||
TypeProto* gathered_prediction_type_proto = GetGatheredPredictionTypeProto(prediction_arg,
|
||||
graph_defs);
|
||||
new_nodes.emplace_back(NodeDef(OpDef{"GatherND", kOnnxDomain, 12},
|
||||
{ArgDef(prediction_masked_lm), ArgDef("masked_lm_positions_unsqueezed")},
|
||||
{ArgDef("gathered_prediction", gathered_prediction_type_proto)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("batch_dims", static_cast<int64_t>(1))},
|
||||
"GATHERED_LM"));
|
||||
|
||||
TypeProto* masked_lm_float_type_proto = GetMaskedLMTypeProto(prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
graph_defs);
|
||||
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
|
||||
{ArgDef("gathered_prediction", gathered_prediction_type_proto),
|
||||
ArgDef(masked_lm_ids, masked_lm_int64_type_proto),
|
||||
ArgDef(masked_lm_weights, masked_lm_float_type_proto)}, // Inputs
|
||||
{ArgDef(mlm_loss, GetLossTypeProto(graph_defs)), // Outputs
|
||||
ArgDef("probability_lm", gathered_prediction_type_proto)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
|
||||
"Masked_LM_Loss"));
|
||||
}
|
||||
|
||||
// LabelSoftmaxCrossEntropy for next_sentence
|
||||
{
|
||||
const NodeArg* ns_prediction_arg = graph.GetNodeArg(prediction_next_sentence);
|
||||
ORT_ENFORCE(ns_prediction_arg != nullptr,
|
||||
"Next sentence prediction arg ", prediction_next_sentence, " is not found in the graph.");
|
||||
TypeProto* next_sentence_labels_type_proto = GetSparseTypeProto(ns_prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
graph_defs);
|
||||
|
||||
new_nodes.emplace_back(NodeDef("SparseSoftmaxCrossEntropy",
|
||||
{ArgDef(prediction_next_sentence),
|
||||
ArgDef(next_sentence_labels, next_sentence_labels_type_proto)}, // Inputs
|
||||
{ArgDef(nsp_loss, GetLossTypeProto(graph_defs)),
|
||||
ArgDef("probability_ns", ns_prediction_arg->TypeAsProto())}, // Outputs
|
||||
{ONNX_NAMESPACE::MakeAttribute("reduction", "mean")},
|
||||
"Next_Sentence_Loss"));
|
||||
}
|
||||
|
||||
// Add
|
||||
{
|
||||
new_nodes.emplace_back(NodeDef("Add", // Op
|
||||
{
|
||||
ArgDef(mlm_loss),
|
||||
ArgDef(nsp_loss) // Inputs
|
||||
},
|
||||
{
|
||||
ArgDef(total_loss, GetLossTypeProto(graph_defs)) // Outputs
|
||||
},
|
||||
NodeAttributes(),
|
||||
"Bert_Total_Loss" // name
|
||||
));
|
||||
}
|
||||
|
||||
graph_defs.AddNodeDefs(new_nodes);
|
||||
graph_defs.AddGraphInputs({masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels});
|
||||
graph_defs.AddGraphOutputs({mlm_loss, nsp_loss, total_loss});
|
||||
|
||||
return graph_defs;
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "orttraining/core/graph/loss_func/loss_func_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
struct BertLossLegacy : public ILossFunction {
|
||||
GraphAugmenter::GraphDefs operator()(const Graph& graph, const LossFunctionInfo& loss_func_info) override;
|
||||
|
||||
private:
|
||||
static TypeProto* GetMaskedLMTypeProto(const NodeArg* prediction_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType data_type,
|
||||
GraphAugmenter::GraphDefs& graph_defs);
|
||||
static TypeProto* GetGatheredPredictionTypeProto(const NodeArg* prediction_arg,
|
||||
GraphAugmenter::GraphDefs& graph_defs);
|
||||
static TypeProto* GetLossTypeProto(GraphAugmenter::GraphDefs& graph_defs);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "loss_function_builder.h"
|
||||
#include "loss_func/mean_squared_error.h"
|
||||
#include "loss_func/bert_loss.h"
|
||||
#include "loss_func/bert_loss_legacy.h"
|
||||
#include "loss_func/softmax_cross_entropy.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -56,6 +57,7 @@ void LossFunctionRegistry::RegisterNonOperatorLossFunctions() {
|
|||
// Register non-operator loss functions here.
|
||||
REGISTER_NON_OPERATOR_LOSS_FUNCTION(MeanSquaredError);
|
||||
REGISTER_NON_OPERATOR_LOSS_FUNCTION(BertLoss);
|
||||
REGISTER_NON_OPERATOR_LOSS_FUNCTION(BertLossLegacy);
|
||||
REGISTER_NON_OPERATOR_LOSS_FUNCTION(SoftmaxCrossEntropy);
|
||||
REGISTER_NON_OPERATOR_LOSS_FUNCTION(SparseSoftmaxCrossEntropy);
|
||||
REGISTER_NON_OPERATOR_LOSS_FUNCTION(SoftmaxCrossEntropyLoss);
|
||||
|
|
|
|||
Loading…
Reference in a new issue