mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Fix ReduceSum in attention fusion (#15047)
Fixes https://github.com/microsoft/onnxruntime/issues/14959. ReduceSum-13 moved axes from an attribute to a node input.
This commit is contained in:
parent
c70838cbbb
commit
bdfdebfca7
1 changed file with 29 additions and 7 deletions
|
|
@ -28,6 +28,7 @@ class AttentionMask:
|
|||
self.mask_casted = {}
|
||||
self.utils = FusionUtils(model)
|
||||
self.mask_format = AttentionMaskFormat.MaskIndexEnd
|
||||
self.opset_version = model.get_opset_version()
|
||||
|
||||
def set_mask_format(self, mask_format: AttentionMaskFormat):
|
||||
self.mask_format = mask_format
|
||||
|
|
@ -65,13 +66,34 @@ class AttentionMask:
|
|||
|
||||
# Add a mask processing node to convert attention mask to mask index (1D)
|
||||
output_name = self.model.create_node_name("mask_index")
|
||||
mask_index_node = helper.make_node(
|
||||
"ReduceSum",
|
||||
inputs=[input_name],
|
||||
outputs=[output_name],
|
||||
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
|
||||
)
|
||||
mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
|
||||
if self.opset_version < 13:
|
||||
mask_index_node = helper.make_node(
|
||||
"ReduceSum",
|
||||
inputs=[input_name],
|
||||
outputs=[output_name],
|
||||
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
|
||||
)
|
||||
mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
|
||||
else:
|
||||
# ReduceSum-13: axes is moved from attribute to input
|
||||
axes_name = "ort_const_1_reduce_sum_axes"
|
||||
if self.model.get_initializer(axes_name) is None:
|
||||
self.model.add_initializer(
|
||||
helper.make_tensor(
|
||||
name=axes_name,
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[1],
|
||||
vals=[1],
|
||||
)
|
||||
)
|
||||
mask_index_node = helper.make_node(
|
||||
"ReduceSum",
|
||||
inputs=[input_name, axes_name],
|
||||
outputs=[output_name],
|
||||
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
|
||||
)
|
||||
mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
|
||||
|
||||
self.model.add_node(mask_index_node)
|
||||
|
||||
self.mask_indice[input] = output_name
|
||||
|
|
|
|||
Loading…
Reference in a new issue