diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py index 63b9fe769..f07a5bf5a 100644 --- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -2227,7 +2227,7 @@ class Kosmos2_5ForConditionalGeneration(Kosmos2_5PreTrainedModel, GenerationMixi **model_kwargs, ) - if past_key_values is None: + if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need `flattened_patches` to be passed to model model_inputs["flattened_patches"] = flattened_patches