Ability to fuse non-square (pruned) attention weights for BERT-like models (#6850)

This commit is contained in:
Funtowicz Morgan 2021-03-05 02:08:08 +01:00 committed by GitHub
parent f986ffcb5f
commit 9126faa35b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -104,25 +104,28 @@ class FusionAttention(Fusion):
if not (k_weight and v_weight and q_bias and k_bias):
return None
assert (self.hidden_size % self.num_heads) == 0
head_hidden_size = self.hidden_size // self.num_heads
qw = numpy_helper.to_array(q_weight)
assert qw.shape == (self.hidden_size, self.hidden_size)
kw = numpy_helper.to_array(k_weight)
assert kw.shape == (self.hidden_size, self.hidden_size)
vw = numpy_helper.to_array(v_weight)
assert vw.shape == (self.hidden_size, self.hidden_size)
# Check if all matrices have the same shape
assert qw.shape == kw.shape == vw.shape
# All the matrices have the same shape (in_size, out_size)
in_size, out_size = qw.shape
qkv_weight = np.stack((qw, kw, vw), axis=-2)
qb = numpy_helper.to_array(q_bias)
assert qb.shape == (self.hidden_size, )
assert qb.shape == (out_size, )
kb = numpy_helper.to_array(k_bias)
assert kb.shape == (self.hidden_size, )
assert kb.shape == (out_size, )
vb = numpy_helper.to_array(v_bias)
assert vb.shape == (self.hidden_size, )
assert vb.shape == (out_size, )
qkv_bias = np.stack((qb, kb, vb), axis=-2)
@ -130,7 +133,7 @@ class FusionAttention(Fusion):
weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',
data_type=TensorProto.FLOAT,
dims=[self.hidden_size, 3 * self.hidden_size],
dims=[in_size, 3 * out_size],
vals=qkv_weight.flatten().tolist())
# Sometimes weights and bias are stored in fp16
if q_weight.data_type == 10:
@ -139,7 +142,7 @@ class FusionAttention(Fusion):
bias = helper.make_tensor(name=attention_node_name + '_qkv_bias',
data_type=TensorProto.FLOAT,
dims=[3 * self.hidden_size],
dims=[3 * out_size],
vals=qkv_bias.flatten().tolist())
if q_bias.data_type == 10:
bias.CopyFrom(numpy_helper.from_array(numpy_helper.to_array(bias).astype(np.float16), bias.name))
@ -154,7 +157,7 @@ class FusionAttention(Fusion):
outputs=[output],
name=attention_node_name)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
attention_node.attribute.extend([helper.make_attribute("num_heads", out_size // head_hidden_size)])
return attention_node