diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc index 38e464e88b..05312a7510 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc @@ -94,7 +94,7 @@ GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const L {ONNX_NAMESPACE::MakeAttribute("axes", std::vector{static_cast(2)})}, "Mask_LM_Positions_Unsqueezed")); } else { - auto t_proto = ONNX_NAMESPACE::ToTensor(1); + auto t_proto = ONNX_NAMESPACE::ToTensor(2); TypeProto* int64_t_proto = graph_defs.CreateTypeProto(); int64_t_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64); @@ -105,7 +105,7 @@ GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const L {ONNX_NAMESPACE::MakeAttribute("value", t_proto)})); new_nodes.emplace_back(NodeDef("Unsqueeze", {ArgDef(masked_lm_positions, masked_lm_int64_type_proto), - ArgDef("two_constant", int64_t_proto)}, + ArgDef(two_constant, int64_t_proto)}, {ArgDef("masked_lm_positions_unsqueezed")}, NodeAttributes(), "Mask_LM_Positions_Unsqueezed"));