Fix ReduceSum in attention fusion (#15047)

Fix https://github.com/microsoft/onnxruntime/issues/14959.
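ReduceSum-13 moves `axes` from an attribute to a node input, so the attention fusion now passes the axes as an initializer input when the model's opset version is 13 or higher.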
This commit is contained in:
Tianlei Wu 2023-03-14 20:34:17 -07:00 committed by GitHub
parent c70838cbbb
commit bdfdebfca7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -28,6 +28,7 @@ class AttentionMask:
self.mask_casted = {}
self.utils = FusionUtils(model)
self.mask_format = AttentionMaskFormat.MaskIndexEnd
self.opset_version = model.get_opset_version()
def set_mask_format(self, mask_format: AttentionMaskFormat):
self.mask_format = mask_format
@ -65,13 +66,34 @@ class AttentionMask:
# Add a mask processing node to convert attention mask to mask index (1D)
output_name = self.model.create_node_name("mask_index")
mask_index_node = helper.make_node(
"ReduceSum",
inputs=[input_name],
outputs=[output_name],
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
)
mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
if self.opset_version < 13:
mask_index_node = helper.make_node(
"ReduceSum",
inputs=[input_name],
outputs=[output_name],
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
)
mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
else:
# ReduceSum-13: axes is moved from attribute to input
axes_name = "ort_const_1_reduce_sum_axes"
if self.model.get_initializer(axes_name) is None:
self.model.add_initializer(
helper.make_tensor(
name=axes_name,
data_type=TensorProto.INT64,
dims=[1],
vals=[1],
)
)
mask_index_node = helper.make_node(
"ReduceSum",
inputs=[input_name, axes_name],
outputs=[output_name],
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
)
mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
self.model.add_node(mask_index_node)
self.mask_indice[input] = output_name