From 862bc8c7a032ecb7fe6098d5feb56981ff17bd6f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 19 Jul 2021 17:24:10 -0700 Subject: [PATCH] shape infer for present output of Attention op (#8430) --- onnxruntime/python/tools/symbolic_shape_infer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index fdccf86fd4..26b5da9d10 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1475,7 +1475,6 @@ class SymbolicShapeInference: vi.CopyFrom(new_vi) def _infer_Attention(self, node): - #TODO: shape inference for the other output (present). shape = self._get_shape(node, 0) shape_bias = self._get_shape(node, 2) assert len(shape) == 3 and len(shape_bias) == 1 @@ -1489,6 +1488,17 @@ class SymbolicShapeInference: vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + if len(node.output) > 1: + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size) + past_shape = self._get_shape(node, 4) + mask_shape = self._get_shape(node, 3) + if len(past_shape) == 5 and len(mask_shape) in [2, 3]: + past_shape[3] = mask_shape[-1] + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + def _infer_BiasGelu(self, node): self._propagate_shape_and_type(node)