Fix bug in Attention Fusion (#13050)

This commit is contained in:
Hariharan Seshadri 2022-09-22 13:46:59 -07:00 committed by GitHub
parent cccbe90764
commit 057567f39f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -242,7 +242,7 @@ class FusionAttention(Fusion):
# For 2d weights, the shapes would be [in_size, out_size].
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
qw_out_size = np.prod(qw.shape[1:])
kw_out_size = np.prod(qw.shape[1:])
kw_out_size = np.prod(kw.shape[1:])
vw_out_size = np.prod(vw.shape[1:])
qkv_weight_dim = 0