fix(FA): QKV not being casted to target_dtype for FA with dpo lora (#35834)

fix(FA): QKV not being casted to target_dtype due to dtype check
This commit is contained in:
NanoCode012 2025-01-28 23:06:56 +07:00 committed by GitHub
parent ece8c42488
commit 478c4f2d0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -209,7 +209,7 @@ def fa_peft_integration_check(
if target_dtype is None:
return query, key, value
input_dtype = value.dtype
input_dtype = query.dtype
if input_dtype == torch.float32:
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"