mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-17 01:44:45 +00:00
shape infer for present output of Attention op (#8430)
This commit is contained in:
parent
0f989c6162
commit
862bc8c7a0
1 changed file with 11 additions and 1 deletion
|
|
@ -1475,7 +1475,6 @@ class SymbolicShapeInference:
|
|||
vi.CopyFrom(new_vi)
|
||||
|
||||
def _infer_Attention(self, node):
|
||||
#TODO: shape inference for the other output (present).
|
||||
shape = self._get_shape(node, 0)
|
||||
shape_bias = self._get_shape(node, 2)
|
||||
assert len(shape) == 3 and len(shape_bias) == 1
|
||||
|
|
@ -1489,6 +1488,17 @@ class SymbolicShapeInference:
|
|||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
|
||||
|
||||
if len(node.output) > 1:
|
||||
# past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
|
||||
# mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length)
|
||||
# present shape: (2, batch_size, num_heads, total_sequence_length, head_size)
|
||||
past_shape = self._get_shape(node, 4)
|
||||
mask_shape = self._get_shape(node, 3)
|
||||
if len(past_shape) == 5 and len(mask_shape) in [2, 3]:
|
||||
past_shape[3] = mask_shape[-1]
|
||||
vi = self.known_vi_[node.output[1]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
|
||||
|
||||
def _infer_BiasGelu(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue