diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5b75a2519..3abd604cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1394,7 +1394,7 @@ class GenerationMixin: return model_kwargs past_length = 0 - if "past_key_values" in model_kwargs: + if model_kwargs.get("past_key_values") is not None: if isinstance(model_kwargs["past_key_values"], Cache): past_length = model_kwargs["past_key_values"].get_seq_length() else: