diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 65ba83b8d9..c05424e39c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -28,6 +28,7 @@ class AttentionMask: self.mask_casted = {} self.utils = FusionUtils(model) self.mask_format = AttentionMaskFormat.MaskIndexEnd + self.opset_version = model.get_opset_version() def set_mask_format(self, mask_format: AttentionMaskFormat): self.mask_format = mask_format @@ -65,13 +66,34 @@ class AttentionMask: # Add a mask processing node to convert attention mask to mask index (1D) output_name = self.model.create_node_name("mask_index") - mask_index_node = helper.make_node( - "ReduceSum", - inputs=[input_name], - outputs=[output_name], - name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), - ) - mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)]) + if self.opset_version < 13: + mask_index_node = helper.make_node( + "ReduceSum", + inputs=[input_name], + outputs=[output_name], + name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), + ) + mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)]) + else: + # ReduceSum-13: axes is moved from attribute to input + axes_name = "ort_const_1_reduce_sum_axes" + if self.model.get_initializer(axes_name) is None: + self.model.add_initializer( + helper.make_tensor( + name=axes_name, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + ) + ) + mask_index_node = helper.make_node( + "ReduceSum", + inputs=[input_name, axes_name], + outputs=[output_name], + name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), + ) + mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)]) + self.model.add_node(mask_index_node) self.mask_indice[input] = output_name