mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Ability to fuse non-square (pruned) attention weights for BERT-like models (#6850)
This commit is contained in:
parent
f986ffcb5f
commit
9126faa35b
1 changed file with 14 additions and 11 deletions
|
|
@@ -104,25 +104,28 @@ class FusionAttention(Fusion):
|
|||
if not (k_weight and v_weight and q_bias and k_bias):
|
||||
return None
|
||||
|
||||
assert (self.hidden_size % self.num_heads) == 0
|
||||
head_hidden_size = self.hidden_size // self.num_heads
|
||||
|
||||
qw = numpy_helper.to_array(q_weight)
|
||||
assert qw.shape == (self.hidden_size, self.hidden_size)
|
||||
|
||||
kw = numpy_helper.to_array(k_weight)
|
||||
assert kw.shape == (self.hidden_size, self.hidden_size)
|
||||
|
||||
vw = numpy_helper.to_array(v_weight)
|
||||
assert vw.shape == (self.hidden_size, self.hidden_size)
|
||||
|
||||
# Check if all matrices have the same shape
|
||||
assert qw.shape == kw.shape == vw.shape
|
||||
|
||||
# All the matrices have the same shape (in_size, out_size)
|
||||
in_size, out_size = qw.shape
|
||||
qkv_weight = np.stack((qw, kw, vw), axis=-2)
|
||||
|
||||
qb = numpy_helper.to_array(q_bias)
|
||||
assert qb.shape == (self.hidden_size, )
|
||||
assert qb.shape == (out_size, )
|
||||
|
||||
kb = numpy_helper.to_array(k_bias)
|
||||
assert kb.shape == (self.hidden_size, )
|
||||
assert kb.shape == (out_size, )
|
||||
|
||||
vb = numpy_helper.to_array(v_bias)
|
||||
assert vb.shape == (self.hidden_size, )
|
||||
assert vb.shape == (out_size, )
|
||||
|
||||
qkv_bias = np.stack((qb, kb, vb), axis=-2)
|
||||
|
||||
|
|
@@ -130,7 +133,7 @@ class FusionAttention(Fusion):
|
|||
|
||||
weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[self.hidden_size, 3 * self.hidden_size],
|
||||
dims=[in_size, 3 * out_size],
|
||||
vals=qkv_weight.flatten().tolist())
|
||||
# Sometimes weights and bias are stored in fp16
|
||||
if q_weight.data_type == 10:
|
||||
|
|
@@ -139,7 +142,7 @@ class FusionAttention(Fusion):
|
|||
|
||||
bias = helper.make_tensor(name=attention_node_name + '_qkv_bias',
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=[3 * self.hidden_size],
|
||||
dims=[3 * out_size],
|
||||
vals=qkv_bias.flatten().tolist())
|
||||
if q_bias.data_type == 10:
|
||||
bias.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(bias).astype(np.float16), bias.name))
|
||||
|
|
@@ -154,7 +157,7 @@ class FusionAttention(Fusion):
|
|||
outputs=[output],
|
||||
name=attention_node_name)
|
||||
attention_node.domain = "com.microsoft"
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
|
||||
attention_node.attribute.extend([helper.make_attribute("num_heads", out_size // head_hidden_size)])
|
||||
|
||||
return attention_node
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue