diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 6adda0036..08b1a7481 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -209,7 +209,7 @@ def fa_peft_integration_check(
     if target_dtype is None:
         return query, key, value

-    input_dtype = value.dtype
+    input_dtype = query.dtype
    if input_dtype == torch.float32:
        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
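For context, the one-line change keys the float32 check off the query tensor's dtype instead of the value tensor's. Below is a minimal sketch of the behavior; `cast_if_upcast` is a hypothetical stand-in for `fa_peft_integration_check` (the real function also emits the warning shown in the hunk), and the mixed-dtype scenario is an assumed motivation: under PEFT, upcast embedding or layer-norm layers can leave the query in float32 while key/value remain in half precision, so checking `value.dtype` would skip a cast that Flash Attention's half-precision kernels need.

```python
import torch


def cast_if_upcast(query, key, value, target_dtype=None):
    """Hypothetical stand-in for fa_peft_integration_check (sketch only)."""
    if target_dtype is None:
        return query, key, value
    # After this diff, the decision is keyed off the query tensor's dtype,
    # not the value tensor's.
    if query.dtype == torch.float32:
        # Flash Attention requires fp16/bf16 inputs, so cast any silently
        # upcast tensors back to the target dtype.
        return query.to(target_dtype), key.to(target_dtype), value.to(target_dtype)
    return query, key, value


# Assumed failure mode the fix addresses: query upcast to float32 (e.g. by a
# float32 layer norm under PEFT) while key/value stay in float16. Checking
# value.dtype would skip the cast; checking query.dtype performs it.
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)
v = torch.randn(1, 8, 16, 64, dtype=torch.float16)
q, k, v = cast_if_upcast(q, k, v, target_dtype=torch.float16)
assert q.dtype == k.dtype == v.dtype == torch.float16
```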