mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Generate: Simplify is_pad_token_not_equal_to_eos_token_id (#18933)
This commit is contained in:
parent
85125fcffd
commit
f1a6df3210
2 changed files with 3 additions and 6 deletions
|
|
```diff
@@ -1739,9 +1739,7 @@ class TFGenerationMixin:
     ) -> tf.Tensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
         is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
-        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
-            (eos_token_id is not None) and (pad_token_id != eos_token_id)
-        )
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)

         # Check if input is input_ids and padded -> only then is attention_mask defined
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
```
|
|
```diff
@@ -495,9 +495,8 @@ class GenerationMixin:
     ) -> torch.LongTensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
         is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
-        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
-            (eos_token_id is not None) and (pad_token_id != eos_token_id)
-        )
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
+
         # Check if input is input_ids and padded -> only then is attention_mask defined
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return inputs.ne(pad_token_id).long()
```
Loading…
Reference in a new issue