diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index da984c44ad..d2b6ebd83e 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -104,25 +104,28 @@ class FusionAttention(Fusion): if not (k_weight and v_weight and q_bias and k_bias): return None + assert (self.hidden_size % self.num_heads) == 0 + head_hidden_size = self.hidden_size // self.num_heads + qw = numpy_helper.to_array(q_weight) - assert qw.shape == (self.hidden_size, self.hidden_size) - kw = numpy_helper.to_array(k_weight) - assert kw.shape == (self.hidden_size, self.hidden_size) - vw = numpy_helper.to_array(v_weight) - assert vw.shape == (self.hidden_size, self.hidden_size) + # Check if all matrices have the same shape + assert qw.shape == kw.shape == vw.shape + + # All the matrices have the same shape (in_size, out_size) + in_size, out_size = qw.shape qkv_weight = np.stack((qw, kw, vw), axis=-2) qb = numpy_helper.to_array(q_bias) - assert qb.shape == (self.hidden_size, ) + assert qb.shape == (out_size, ) kb = numpy_helper.to_array(k_bias) - assert kb.shape == (self.hidden_size, ) + assert kb.shape == (out_size, ) vb = numpy_helper.to_array(v_bias) - assert vb.shape == (self.hidden_size, ) + assert vb.shape == (out_size, ) qkv_bias = np.stack((qb, kb, vb), axis=-2) @@ -130,7 +133,7 @@ class FusionAttention(Fusion): weight = helper.make_tensor(name=attention_node_name + '_qkv_weight', data_type=TensorProto.FLOAT, - dims=[self.hidden_size, 3 * self.hidden_size], + dims=[in_size, 3 * out_size], vals=qkv_weight.flatten().tolist()) # Sometimes weights and bias are stored in fp16 if q_weight.data_type == 10: @@ -139,7 +142,7 @@ class FusionAttention(Fusion): bias = helper.make_tensor(name=attention_node_name + '_qkv_bias', data_type=TensorProto.FLOAT, - dims=[3 * self.hidden_size], + dims=[3 * out_size], vals=qkv_bias.flatten().tolist()) if q_bias.data_type == 10: bias.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(bias).astype(np.float16), bias.name)) @@ -154,7 +157,7 @@ class FusionAttention(Fusion): outputs=[output], name=attention_node_name) attention_node.domain = "com.microsoft" - attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)]) + attention_node.attribute.extend([helper.make_attribute("num_heads", out_size // head_hidden_size)]) return attention_node