From e467d78a117a6c7e5d3fb5c09a9ebc949f52a2de Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Fri, 9 Jul 2021 09:24:43 -0700 Subject: [PATCH] fix a typo (#8334) Co-authored-by: Cheng Tang --- .../orttraining/core/graph/loss_func/bert_loss_legacy.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc index 38e464e88b..05312a7510 100644 --- a/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc +++ b/orttraining/orttraining/core/graph/loss_func/bert_loss_legacy.cc @@ -94,7 +94,7 @@ GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const L {ONNX_NAMESPACE::MakeAttribute("axes", std::vector{static_cast(2)})}, "Mask_LM_Positions_Unsqueezed")); } else { - auto t_proto = ONNX_NAMESPACE::ToTensor(1); + auto t_proto = ONNX_NAMESPACE::ToTensor(2); TypeProto* int64_t_proto = graph_defs.CreateTypeProto(); int64_t_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64); @@ -105,7 +105,7 @@ GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const L {ONNX_NAMESPACE::MakeAttribute("value", t_proto)})); new_nodes.emplace_back(NodeDef("Unsqueeze", {ArgDef(masked_lm_positions, masked_lm_int64_type_proto), - ArgDef("two_constant", int64_t_proto)}, + ArgDef(two_constant, int64_t_proto)}, {ArgDef("masked_lm_positions_unsqueezed")}, NodeAttributes(), "Mask_LM_Positions_Unsqueezed"));