OPT: Fix batched generation with FLAX (#21150)
* Fix Flax OPT numerical masking
* Re-enable the batched generation test
* Add the fix to Bart and reintroduce the "# Copied from" markers in OPT
parent f4786d7f39
commit e15f0d73db

8 changed files with 34 additions and 38 deletions
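
The change itself is one line, repeated across seven Flax attention implementations: the bias added to masked attention positions is switched from float("-inf") to the large finite constant -1e9. The snippet below is a minimal illustration, not part of the commit, of why this fixes batched generation: with left padding, a padded query row can have every key masked out, and softmax over an all -inf row produces NaN, which then corrupts the rest of the batched forward pass. A large finite bias collapses to a harmless uniform distribution instead.

import jax.numpy as jnp
from jax import nn

# Attention logits for one query row whose keys are all masked out,
# as happens at left-padded positions in a batched prompt.
scores = jnp.zeros(4)

bias_inf = jnp.full(4, float("-inf"))  # old fill value
bias_finite = jnp.full(4, -1e9)        # new fill value

print(nn.softmax(scores + bias_inf))     # [nan nan nan nan]
print(nn.softmax(scores + bias_finite))  # [0.25 0.25 0.25 0.25]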
src/transformers/models/bart/modeling_flax_bart.py
@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/marian/modeling_flax_marian.py
@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/mbart/modeling_flax_mbart.py
@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/opt/modeling_flax_opt.py
@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

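FlaxOPTAttention is maintained through the transformers "# Copied from" mechanism, which is why the identical one-line fix appears in all seven modeling files: the canonical change lands in FlaxBartAttention and is propagated to the copies. Per the commit message, the marker is reintroduced in modeling_flax_opt.py; it reads along the lines of (sketch, exact text not shown in this diff):

# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT
class FlaxOPTAttention(nn.Module):
    ...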
src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
             )
         else:
             attention_bias = None

tests/models/opt/test_modeling_flax_opt.py
@@ -364,43 +364,39 @@ class FlaxOPTGenerationTest(unittest.TestCase):
 
         self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
 
-    # TODO fix in the following PR
-    # def test_batch_generation(self):
-    #     model_id = "facebook/opt-350m"
+    def test_batch_generation(self):
+        model_id = "facebook/opt-350m"
 
-    #     tokenizer = GPT2Tokenizer.from_pretrained(model_id)
-    #     model = FlaxOPTForCausalLM.from_pretrained(model_id)
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = FlaxOPTForCausalLM.from_pretrained(model_id)
 
-    #     tokenizer.padding_side = "left"
+        tokenizer.padding_side = "left"
 
-    #     # use different length sentences to test batching
-    #     sentences = [
-    #         "Hello, my dog is a little",
-    #         "Today, I",
-    #     ]
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
 
-    #     inputs = tokenizer(sentences, return_tensors="jax", padding=True)
-    #     input_ids = inputs["input_ids"]
+        inputs = tokenizer(sentences, return_tensors="jax", padding=True)
+        input_ids = inputs["input_ids"]
 
-    #     outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False)
+        outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False)
 
-    #     inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids
-    #     output_non_padded = model.generate(input_ids=inputs_non_padded)
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
 
-    #     num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum()
-    #     inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids
-    #     output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum()
+        inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
 
-    #     batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
-    #     non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True)
-    #     padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True)
+        batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True)
 
-    #     expected_output_sentence = [
-    #         "Hello, my dog is a little bit of a dork.\nI'm a little bit",
-    #         "Today, I<s><s><s><s><s><s><s><s><s><s><s><s>"
-    #         # TODO fix this test in next PR
-    #         # "Today, I was in the middle of a conversation with a friend about the",
-    #     ]
-    #     self.assertListEqual(expected_output_sentence, batch_out_sentence)
-    #     # TODO outputs will be similar, fix in next PR
-    #     self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+            "Today, I was in the middle of a conversation with a friend about the",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
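
Two details of the re-enabled test are worth spelling out (this note is not part of the commit). First, padding_side = "left" matters because decoder-only models like OPT generate from the last position, so prompts must end at the right edge of the batch. Second, the num_paddings arithmetic caps the standalone run of the shorter sentence at max_length - num_paddings, so it generates the same number of real tokens as it did inside the batch, where pad tokens consumed part of the max_length budget; the padded and non-padded outputs can then be compared verbatim. A minimal sketch of the padding setup, assuming the facebook/opt-350m tokenizer files are available:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
tokenizer.padding_side = "left"

# The shorter prompt receives pad tokens on the left, so its attention_mask
# starts with zeros and both prompts end where generation begins.
batch = tokenizer(["Hello, my dog is a little", "Today, I"], padding=True)
print(batch["attention_mask"])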