From bdfdebfca737b235525a05ee6b574ab2e88a6aa6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 14 Mar 2023 20:34:17 -0700 Subject: [PATCH] Fix ReduceSum in attention fusion (#15047) Fix https://github.com/microsoft/onnxruntime/issues/14959. ReduceSum-13 move axes from attribute to node input. --- .../tools/transformers/fusion_attention.py | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 65ba83b8d9..c05424e39c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -28,6 +28,7 @@ class AttentionMask: self.mask_casted = {} self.utils = FusionUtils(model) self.mask_format = AttentionMaskFormat.MaskIndexEnd + self.opset_version = model.get_opset_version() def set_mask_format(self, mask_format: AttentionMaskFormat): self.mask_format = mask_format @@ -65,13 +66,34 @@ class AttentionMask: # Add a mask processing node to convert attention mask to mask index (1D) output_name = self.model.create_node_name("mask_index") - mask_index_node = helper.make_node( - "ReduceSum", - inputs=[input_name], - outputs=[output_name], - name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), - ) - mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)]) + if self.opset_version < 13: + mask_index_node = helper.make_node( + "ReduceSum", + inputs=[input_name], + outputs=[output_name], + name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), + ) + mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)]) + else: + # ReduceSum-13: axes is moved from attribute to input + axes_name = "ort_const_1_reduce_sum_axes" + if self.model.get_initializer(axes_name) is None: + self.model.add_initializer( + helper.make_tensor( + name=axes_name, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + ) + ) + mask_index_node = helper.make_node( + "ReduceSum", + inputs=[input_name, axes_name], + outputs=[output_name], + name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), + ) + mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)]) + self.model.add_node(mask_index_node) self.mask_indice[input] = output_name