diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index b6d36d6e08..b549980719 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -242,7 +242,7 @@ class FusionAttention(Fusion): # For 2d weights, the shapes would be [in_size, out_size]. # For 3d weights, shape would be [in_size, a, b] where a*b = out_size qw_out_size = np.prod(qw.shape[1:]) - kw_out_size = np.prod(qw.shape[1:]) + kw_out_size = np.prod(kw.shape[1:]) vw_out_size = np.prod(vw.shape[1:]) qkv_weight_dim = 0