fix assert error in attention fusion script (#17375)

Add a check of num_heads and hidden_size to avoid assert error (https://github.com/microsoft/onnxruntime/issues/17254)
This commit is contained in:
Tianlei Wu 2023-09-01 08:18:50 -07:00 committed by GitHub
parent e23f16adbf
commit e745575187
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1166,6 +1166,13 @@ class FusionAttention(Fusion):
attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if q_num_heads <= 0 or q_hidden_size <= 0:
logger.warning(
"Failed to detect num_heads and hidden_size for Attention fusion. "
"Please specify those parameters in argument."
)
return
# number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
# the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
new_node = self.create_attention_node(