Donut: fix generate call from local path (#31470)

* local donut path fix

* engrish

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante 2024-06-18 13:28:06 +01:00 committed by GitHub
parent 76289fbc7c
commit cd71f9381b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -575,8 +575,12 @@ class GenerationMixin:
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_start_token_id
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
# original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
# See: https://github.com/huggingface/transformers/pull/31470
elif "donut" in self.__class__.__name__.lower() or (
self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):
pass
elif self.config.model_type in ["whisper"]:
pass