Generation: fix handling of special tokens (#31254)

* fix special tokens in generation
* fix test
* add warning
* fix the check
* warn once
* fix
parent 7729b77478
commit 5fabd1e83b
2 changed files with 29 additions and 30 deletions
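In short, the commit makes `generate()` resolve each special token from the call-site kwargs first and from the model's own `generation_config` second. A minimal standalone sketch of that precedence rule (the function name here is illustrative, not the library API):

import torch

# Hypothetical stand-in for the new _tensor_or_none helper: resolve a special
# token from the call-site kwargs first, then from the model-level config.
def resolve_special_token(token_kwargs, token_self, device="cpu"):
    token = token_kwargs if token_kwargs is not None else token_self
    if token is None or isinstance(token, torch.Tensor):
        return token
    # Plain ints (or lists of ints, e.g. several eos ids) become long tensors.
    return torch.tensor(token, device=device, dtype=torch.long)

print(resolve_special_token(50256, 0))    # kwargs win -> tensor(50256)
print(resolve_special_token(None, 0))     # falls back to model config -> tensor(0)
print(resolve_special_token(None, None))  # nothing set anywhere -> None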
@@ -1436,23 +1436,6 @@ class GenerationMixin:
         self._cache.reset()
         return self._cache
 
-    def _get_decoder_start_token_id(
-        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
-    ) -> int:
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.generation_config.decoder_start_token_id
-        )
-        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
-
-        if decoder_start_token_id is not None:
-            return decoder_start_token_id
-        elif bos_token_id is not None:
-            return bos_token_id
-        else:
-            return
-
     def _supports_default_dynamic_cache(self) -> bool:
         """
         Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
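The deleted helper's fallback chain survives as a one-liner inlined in the next hunk. A plain-Python restatement of the precedence it implemented (the config lookups it also performed are now covered by `_tensor_or_none`):

# Restatement of the removed _get_decoder_start_token_id fallback, for reference.
def decoder_start_or_bos(decoder_start_token_id, bos_token_id):
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    return bos_token_id  # may itself be None, matching the old bare `return`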
@@ -1478,25 +1461,32 @@ class GenerationMixin:
         function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
         """
 
-        # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token, device=None):
+        # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
+        def _tensor_or_none(token_kwargs, token_self, device=None):
             if device is None:
                 device = self.device
 
+            token = token_kwargs if token_kwargs is not None else token_self
             if token is None or isinstance(token, torch.Tensor):
                 return token
             return torch.tensor(token, device=device, dtype=torch.long)
 
-        # for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
-        if self.config.is_encoder_decoder:
-            generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
-                generation_config.decoder_start_token_id, generation_config.bos_token_id
-            )
-
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
+        bos_token_id = _tensor_or_none(
+            generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
+        )
+        eos_token_id = _tensor_or_none(
+            generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
+        )
+        pad_token_id = _tensor_or_none(
+            generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
+        )
+        decoder_start_token_id = _tensor_or_none(
+            generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
+        )
+
+        # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
+        if self.config.is_encoder_decoder:
+            decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
         if eos_token_id is not None and eos_token_id.ndim == 0:
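With this resolution order, call-site arguments override the model's `generation_config`, and omitted arguments fall back to it rather than being treated as an explicit `None`. A short sketch of the two levels interacting (checkpoint name reused from the tests below, not part of this hunk):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Special tokens omitted here resolve from model.generation_config ...
out = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)

# ... while ones passed explicitly take precedence over the model-level values.
out = model.generate(
    **inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)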
@@ -1512,6 +1502,15 @@ class GenerationMixin:
             pad_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
 
+        # we can't infer attn mask if pad token is set to be eos token in model's generation config
+        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+            if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+                logger.warning_once(
+                    "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. "
+                    "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
+                    "to obtain reliable results."
+                )
+
         # Sanity checks/warnings
         if self.config.is_encoder_decoder and decoder_start_token_id is None:
             raise ValueError(
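The new warning hinges on `torch.isin`, which handles both a single eos id and several at once. A small demonstration of the check in isolation:

import torch

eos_token_id = torch.tensor([50256, 198])  # more than one eos id is allowed
pad_token_id = torch.tensor(50256)

# True when the pad id collides with any eos id, in which case an attention
# mask can no longer be inferred from the input ids alone.
collides = torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
print(collides)  # tensor(True)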
@@ -161,6 +161,7 @@ class GenerationIntegrationTestsMixin:
         tokenizer.pad_token = tokenizer.eos_token
 
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -170,7 +171,6 @@ class GenerationIntegrationTestsMixin:
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
@@ -197,6 +197,7 @@ class GenerationIntegrationTestsMixin:
         tokenizer.pad_token = tokenizer.eos_token
 
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -206,7 +207,6 @@ class GenerationIntegrationTestsMixin:
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
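These test edits follow directly from the new fallback: passing `eos_token_id=None` to `generate()` no longer disables the eos token, since `None` now means "fall back to the model's generation config", so the tests unset it on the model itself. A condensed sketch of the updated pattern (checkpoint name taken from the tests; the surrounding test scaffolding is omitted):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
model.generation_config.eos_token_id = None  # disable eos at the model level

inputs = tokenizer("Today a dragon flew over Paris.", return_tensors="pt", padding=True)
outputs = model.generate(
    input_ids=inputs.input_ids,
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,  # an eos_token_id=None kwarg would now be a no-op
    return_dict_in_generate=True,
    output_scores=True,
)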