mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
parent
598454bb5f
commit
e467d78a11
1 changed files with 2 additions and 2 deletions
|
|
@ -94,7 +94,7 @@ GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const L
|
|||
{ONNX_NAMESPACE::MakeAttribute("axes", std::vector<int64_t>{static_cast<int64_t>(2)})},
|
||||
"Mask_LM_Positions_Unsqueezed"));
|
||||
} else {
|
||||
auto t_proto = ONNX_NAMESPACE::ToTensor<int64_t>(1);
|
||||
auto t_proto = ONNX_NAMESPACE::ToTensor<int64_t>(2);
|
||||
TypeProto* int64_t_proto = graph_defs.CreateTypeProto();
|
||||
int64_t_proto->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64);
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ GraphAugmenter::GraphDefs BertLossLegacy::operator()(const Graph& graph, const L
|
|||
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)}));
|
||||
new_nodes.emplace_back(NodeDef("Unsqueeze",
|
||||
{ArgDef(masked_lm_positions, masked_lm_int64_type_proto),
|
||||
ArgDef("two_constant", int64_t_proto)},
|
||||
ArgDef(two_constant, int64_t_proto)},
|
||||
{ArgDef("masked_lm_positions_unsqueezed")},
|
||||
NodeAttributes(),
|
||||
"Mask_LM_Positions_Unsqueezed"));
|
||||
|
|
|
|||
Loading…
Reference in a new issue