[Bart/Memory] Two separate, smaller decoder attention masks (#3371)
parent 53fe733805
commit 3ee431dd4c
2 changed files with 71 additions and 82 deletions
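For scale, a rough sketch (not part of the diff) of why two smaller masks save memory: the old code built one broadcastable float mask of shape (batch, 1, tgt_len, tgt_len), while the new code keeps a bool padding mask of shape (batch, tgt_len) plus a single float causal mask of shape (tgt_len, tgt_len). The batch and sequence sizes below are assumptions for illustration only.

# Illustrative only; the shapes follow the old vs. new masks in this commit, the sizes are made up.
import torch

bsz, tgt_len = 8, 1024

# Old: one combined, broadcastable float mask, shape (bsz, 1, tgt_len, tgt_len)
combined_mask = torch.zeros(bsz, 1, tgt_len, tgt_len)

# New: a bool padding mask (bsz, tgt_len) plus one shared float causal mask (tgt_len, tgt_len)
decoder_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), 1)

old_bytes = combined_mask.numel() * combined_mask.element_size()
new_bytes = (decoder_padding_mask.numel() * decoder_padding_mask.element_size()
             + causal_mask.numel() * causal_mask.element_size())
print(old_bytes, new_bytes)  # roughly 33.5 MB vs 4.2 MB for this toy batch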
src/transformers/modeling_bart.py

@@ -74,39 +74,37 @@ BART_INPUTS_DOCSTRING = r"""
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
         decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
             Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
-        decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, 1, tgt_seq_len, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
-            Default behavior: generate a tensor that ignores pad tokens and future tokens, as in the paper.
+        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
+            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
+            If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
+            See diagram 1 in the paper for more info on the default strategy
 """
 LARGE_NEGATIVE = -1e8


+def invert_mask(attention_mask):
+    assert attention_mask.dim() == 2
+    return attention_mask.eq(0)
+
+
 def _prepare_bart_decoder_inputs(
-    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
+    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
 ):
-    """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
+    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
     Note: this is not called during generation
     """
     pad_token_id = config.pad_token_id
-    need_causal_mask = not config.output_past
     if decoder_input_ids is None:
         decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
-    bsz, tgt_len = decoder_input_ids.size()[:2]
-    if decoder_attn_mask is None:
+    bsz, tgt_len = decoder_input_ids.size()
+    if decoder_padding_mask is None:
         decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
-        if need_causal_mask:
-            causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
-        else:
-            causal_lm_mask = None
-        new_shape = (bsz, tgt_len, tgt_len)
-        # make it broadcastable so can just be added to the attention coefficients
-        decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
-    if mask_dtype is not None:
-        decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
-    assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
-    return decoder_input_ids, decoder_attn_mask
+    else:
+        decoder_padding_mask = invert_mask(decoder_padding_mask)
+    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
+        dtype=causal_mask_dtype, device=decoder_input_ids.device
+    )
+    return decoder_input_ids, decoder_padding_mask, causal_mask


 class PretrainedBartModel(PreTrainedModel):
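A self-contained paraphrase of the new mask logic above, using plain torch only. The pad_token_id value and the simplification that the padding mask is always computed (rather than being None when nothing is padded, as the real make_padding_mask helper may do) are assumptions for illustration.

import torch

def prepare_decoder_masks(decoder_input_ids, pad_token_id=1, causal_mask_dtype=torch.float32):
    # decoder_padding_mask: True at pad positions, shape (bsz, tgt_len)
    bsz, tgt_len = decoder_input_ids.size()
    decoder_padding_mask = decoder_input_ids.eq(pad_token_id)
    # causal_mask: -inf strictly above the diagonal, shape (tgt_len, tgt_len)
    causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), 1).to(
        dtype=causal_mask_dtype, device=decoder_input_ids.device
    )
    return decoder_padding_mask, causal_mask

ids = torch.tensor([[0, 26388, 2, 1]])  # last token is the (assumed) pad id
pad_mask, causal = prepare_decoder_masks(ids)
print(pad_mask)   # tensor([[False, False, False,  True]])
print(causal)     # 4x4: zeros on and below the diagonal, -inf above it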
@@ -130,12 +128,9 @@ class PretrainedBartModel(PreTrainedModel):
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
         input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids)
         dummy_inputs = {
-            "decoder_input_ids": decoder_input_ids,
             "attention_mask": input_ids.ne(pad_token),
             "input_ids": input_ids,
-            "decoder_attention_mask": decoder_attn_mask,
         }
         return dummy_inputs

@@ -153,21 +148,6 @@ def _check_shapes(shape_1, shape2):
         raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


-def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
-    """Make one mask of shape (bsz, 1, tgt_len, src_len) """
-    a = torch.zeros(targ_size)  # targ_size is(bsz, tgt_len, src_len)
-    b = torch.zeros(targ_size)
-    if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
-        _check_shapes(key_padding_mask.shape, targ_size[:2])
-        reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size)
-        a[reshaped] = LARGE_NEGATIVE
-
-    if causal_lm_mask is not None:  # (tgt_len, src_len) -> targ_size
-        _check_shapes(causal_lm_mask.shape, targ_size[-2:])
-        b = causal_lm_mask.unsqueeze(0).expand(*targ_size)
-    return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,)
-
-
 def shift_tokens_right(input_ids, pad_token_id):
     """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
     prev_output_tokens = input_ids.clone()
@@ -281,8 +261,7 @@ class BartEncoder(nn.Module):
         """
         # check attention mask and invert
         if attention_mask is not None:
-            assert attention_mask.dim() == 2
-            attention_mask = attention_mask.eq(0)
+            attention_mask = invert_mask(attention_mask)

         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
@@ -339,7 +318,13 @@ class DecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(
-        self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
+        self,
+        x,
+        encoder_hidden_states,
+        encoder_attn_mask=None,
+        layer_state=None,
+        causal_mask=None,
+        decoder_padding_mask=None,
     ):
         residual = x

@@ -347,7 +332,12 @@
             layer_state = {}
         # next line mutates layer state
         x, self_attn_weights = self.self_attn(
-            query=x, key=x, layer_state=layer_state, attn_mask=attention_mask, need_weights=self.output_attentions
+            query=x,
+            key=x,
+            layer_state=layer_state,
+            key_padding_mask=decoder_padding_mask,
+            attn_mask=causal_mask,
+            need_weights=self.output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
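The decoder layer now feeds the two masks to self-attention through separate arguments instead of one pre-combined tensor. Bart's own SelfAttention module is not shown here; the stock torch.nn.MultiheadAttention uses the same key_padding_mask/attn_mask convention, so a quick sketch with it (toy sizes, assumed for illustration):

import torch
from torch import nn

bsz, tgt_len, d_model, heads = 2, 5, 16, 4
x = torch.randn(tgt_len, bsz, d_model)  # (seq_len, batch, dim) layout
attn = nn.MultiheadAttention(d_model, heads)

key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)  # True = ignore this position
key_padding_mask[:, -1] = True  # pretend the last position is padding

causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), 1)  # additive mask

out, _ = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=causal_mask)
print(out.shape)  # torch.Size([5, 2, 16])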
@@ -412,7 +402,8 @@ class BartDecoder(nn.Module):
         input_ids,
         encoder_hidden_states,
         encoder_padding_mask,
-        combined_mask,
+        decoder_padding_mask,
+        decoder_causal_mask,
         decoder_cached_states=None,
         generation_mode=False,
         **unused
@@ -437,8 +428,7 @@
         """
         # check attention mask and invert
         if encoder_padding_mask is not None:
-            assert encoder_padding_mask.dim() == 2
-            encoder_padding_mask = encoder_padding_mask.eq(0)
+            encoder_padding_mask = invert_mask(encoder_padding_mask)

         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=generation_mode)
@@ -458,7 +448,6 @@
         all_hidden_states = ()
         all_self_attns = ()
         next_decoder_cache = []

         for i, decoder_layer in enumerate(self.layers):
-            decoder_layer  # type: DecoderLayer
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -468,7 +457,12 @@

             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
             x, layer_self_attn, layer_past = decoder_layer(
-                x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
+                x,
+                encoder_hidden_states,
+                encoder_attn_mask=encoder_padding_mask,
+                decoder_padding_mask=decoder_padding_mask,
+                layer_state=layer_state,
+                causal_mask=decoder_causal_mask,
             )

             if self.output_past:
@@ -736,6 +730,8 @@ def _filter_out_falsey_values(tup) -> Tuple:


 # Public API
+def _get_shape(t):
+    return getattr(t, "shape", None)


 @add_start_docstrings(
@@ -769,13 +765,16 @@ class BartModel(PretrainedBartModel):

         # make masks if user doesn't supply
         if not generation_mode:
-            decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
+            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                 self.config,
                 input_ids,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attn_mask=decoder_attention_mask,
-                mask_dtype=self.shared.weight.dtype,
+                decoder_padding_mask=decoder_attention_mask,
+                causal_mask_dtype=self.shared.weight.dtype,
             )
+        else:
+            decoder_padding_mask, causal_mask = None, None
+
         assert decoder_input_ids is not None
         if encoder_outputs is None:
             encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
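In BartModel.forward the two masks are now built only when the caller does not supply them, and skipped entirely in generation mode. A small end-to-end sketch of that default code path, assuming a transformers version from around this commit and an arbitrary toy configuration:

import torch
from transformers import BartConfig, BartModel

config = BartConfig(
    vocab_size=99,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = BartModel(config).eval()

input_ids = torch.tensor([[0, 11, 12, 2]])
# No decoder_input_ids or decoder_attention_mask passed: the patched code path
# shifts input_ids right and builds the padding and causal masks itself.
decoder_features = model(input_ids=input_ids)[0]
print(decoder_features.shape)  # torch.Size([1, 4, 16])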
@@ -785,7 +784,8 @@
             decoder_input_ids,
             encoder_outputs[0],
             attention_mask,
-            decoder_attention_mask,
+            decoder_padding_mask,
+            decoder_causal_mask=causal_mask,
             decoder_cached_states=decoder_cached_states,
             generation_mode=generation_mode,
         )

tests/test_modeling_bart.py

@@ -36,8 +36,8 @@ if is_torch_available():
     from transformers.modeling_bart import (
         BART_PRETRAINED_MODEL_ARCHIVE_MAP,
         shift_tokens_right,
+        invert_mask,
         _prepare_bart_decoder_inputs,
-        LARGE_NEGATIVE,
     )
     from transformers.tokenization_bart import BartTokenizer

@@ -123,10 +123,9 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()

-    def test_advanced_inputs(self):
+    def test_initialization_more(self):
         # (config, input_ids, token_type_ids, input_mask, *unused) = \
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"])
         model = BartModel(config)
         model.to(torch_device)
         model.eval()
@@ -142,9 +141,17 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         _check_var(model.encoder.layers[0].fc1)
         _check_var(model.encoder.embed_positions)

+    def test_advanced_inputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        inputs_dict["input_ids"][:, -2:] = config.pad_token_id
+        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
+            config, inputs_dict["input_ids"]
+        )
+        model = BartModel(config).to(torch_device).eval()
+
         decoder_features_with_created_mask = model(**inputs_dict)[0]
         decoder_features_with_passed_mask = model(
-            decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
+            decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
         )[0]
         _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
         useless_mask = torch.zeros_like(decoder_attn_mask)
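The updated test_advanced_inputs flips the mask returned by _prepare_bart_decoder_inputs (True at padding positions) with invert_mask before passing it, so that decoder_attention_mask keeps the public 1/True = keep, 0/False = masked convention; the model applies invert_mask again internally. A tiny sketch of that flip, with toy values assumed:

import torch

public_mask = torch.tensor([[True, True, True, False]])  # True for real tokens, False for padding
internal_padding_mask = public_mask.eq(0)                 # what invert_mask() computes: True at pad positions
print(internal_padding_mask)                              # tensor([[False, False, False,  True]])
print(internal_padding_mask.eq(0))                        # flipping twice recovers the original mask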
@@ -238,7 +245,7 @@ class BartHeadTests(unittest.TestCase):
         lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
         lm_model = BartForConditionalGeneration(config)
         lm_model.to(torch_device)
-        loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids)
+        loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels)
         expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)
         self.assertIsInstance(loss.item(), float)
@@ -336,41 +343,23 @@ class BartHeadTests(unittest.TestCase):
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

     def test_dummy_inputs(self):
-        config, *_ = self._get_config_and_data(output_past=True)
+        config, *_ = self._get_config_and_data()
         model = BartForConditionalGeneration(config).eval().to(torch_device)
         model(**model.dummy_inputs)

     def test_prepare_bart_decoder_inputs(self):
-        config, *_ = self._get_config_and_data(output_past=False)
-        input_ids = _long_tensor(([4, 4, 2]))  # only used for .device if decoder_input_ids is passed
+        config, *_ = self._get_config_and_data()
+        input_ids = _long_tensor(([4, 4, 2]))
         decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
-        ignore = LARGE_NEGATIVE
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
-        expected_mask = torch.tensor(
-            [
-                [0, ignore, ignore],
-                [0, 0, ignore],
-                [ignore, ignore, ignore],  # never attend to the final token, because its pad
-            ]
-        ).to(input_ids.device)
-        self.assertEqual(decoder_attn_mask.size(), (1, 1, 3, 3))
-        self.assertTrue(torch.eq(expected_mask, decoder_attn_mask).all())
-
-        # Test no causal mask
-        config, *_ = self._get_config_and_data(output_past=True)
-        expected_just_padding_mask = torch.tensor(
-            [[0, 0, 0], [0, 0, 0], [ignore, ignore, ignore]]  # never attend to the final token, because its pad
-        ).to(input_ids.device)
-        _, decoder_attn_mask_no_causal_mask = _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids)
-        self.assertEqual(decoder_attn_mask_no_causal_mask.size(), (1, 1, 3, 3))
-        self.assertTrue(torch.eq(expected_just_padding_mask, decoder_attn_mask_no_causal_mask).all())
-
-        decoder_input_ids = _long_tensor([[0, 26388, 4133, 2]])
-        # Attend to everything if no pad tokens and no causal mask
-        _, decoder_attn_mask_no_padding_no_causal_mask = _prepare_bart_decoder_inputs(
+        ignore = float("-inf")
+        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
             config, input_ids, decoder_input_ids
         )
-        self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
+        expected_causal_mask = torch.tensor(
+            [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]]  # never attend to the final token, because its pad
+        ).to(input_ids.device)
+        self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
+        self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())

     def test_resize_tokens_embeddings_more(self):
         config, input_ids, _ = self._get_config_and_data()