mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
fix assert error in attention fusion script (#17375)
Add a check of num_heads and hidden_size to avoid assert error (https://github.com/microsoft/onnxruntime/issues/17254)
This commit is contained in:
parent
e23f16adbf
commit
e745575187
1 changed files with 7 additions and 0 deletions
|
|
@ -1166,6 +1166,13 @@ class FusionAttention(Fusion):
|
|||
attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
|
||||
|
||||
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
|
||||
if q_num_heads <= 0 or q_hidden_size <= 0:
|
||||
logger.warning(
|
||||
"Failed to detect num_heads and hidden_size for Attention fusion. "
|
||||
"Please specify those parameters in argument."
|
||||
)
|
||||
return
|
||||
|
||||
# number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
|
||||
# the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
|
||||
new_node = self.create_attention_node(
|
||||
|
|
|
|||
Loading…
Reference in a new issue