Add shape inference for the present output of the Attention op (#8430)

This commit is contained in:
Tianlei Wu 2021-07-19 17:24:10 -07:00 committed by GitHub
parent 0f989c6162
commit 862bc8c7a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1475,7 +1475,6 @@ class SymbolicShapeInference:
vi.CopyFrom(new_vi)
def _infer_Attention(self, node):
#TODO: shape inference for the other output (present).
shape = self._get_shape(node, 0)
shape_bias = self._get_shape(node, 2)
assert len(shape) == 3 and len(shape_bias) == 1
@ -1489,6 +1488,17 @@ class SymbolicShapeInference:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
if len(node.output) > 1:
# past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
# mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length)
# present shape: (2, batch_size, num_heads, total_sequence_length, head_size)
past_shape = self._get_shape(node, 4)
mask_shape = self._get_shape(node, 3)
if len(past_shape) == 5 and len(mask_shape) in [2, 3]:
past_shape[3] = mask_shape[-1]
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
def _infer_BiasGelu(self, node):
self._propagate_shape_and_type(node)