mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-15 21:01:19 +00:00
Donut: fix generate call from local path (#31470)
* local donut path fix * engrish * Update src/transformers/generation/utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
76289fbc7c
commit
cd71f9381b
1 changed files with 6 additions and 2 deletions
|
|
@ -575,8 +575,12 @@ class GenerationMixin:
|
|||
# no user input -> use decoder_start_token_id as decoder_input_ids
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = decoder_start_token_id
|
||||
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
|
||||
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
|
||||
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
|
||||
# original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
|
||||
# See: https://github.com/huggingface/transformers/pull/31470
|
||||
elif "donut" in self.__class__.__name__.lower() or (
|
||||
self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
|
||||
):
|
||||
pass
|
||||
elif self.config.model_type in ["whisper"]:
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue