mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Iterative generation using Input embeds and past_key_values (#35890)
* Iterative generation using input embeds
* ruff fix
* Added Testcase
* Updated comment
* ♻️ Refactored testcase
* Skip test for these models
* Continue generation using input embeds and cache
* Skip generate_continue_from_embeds test
* Refactor `prepare_input_for_generation` func
* Continue generation using input embeds and cache
* Modular changes fix
* Overwrite 'prepare_inputs_for_generation' function
This commit is contained in:
parent
b5f327f350
commit
7aee036e54
18 changed files with 276 additions and 34 deletions
|
|
@ -381,9 +381,13 @@ class GenerationMixin:
|
|||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
model_inputs["past_key_values"] = past_key_values
|
||||
if (
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
|
|
@ -393,9 +397,9 @@ class GenerationMixin:
|
|||
|
||||
# 3. Prepare base model inputs
|
||||
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
|
||||
if not self.config.is_encoder_decoder:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs[input_ids_key] = None
|
||||
model_inputs["inputs_embeds"] = inputs_embeds
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -895,8 +895,12 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
|||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
|
|
@ -905,7 +909,7 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
|||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
|
||||
|
|
|
|||
|
|
@ -1654,8 +1654,12 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
|
|||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
|
|
@ -1668,10 +1672,13 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
|
|||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0:
|
||||
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
|
||||
else:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
|
|
|||
|
|
@ -1674,10 +1674,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
|
|||
else:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# If we have cache: let's slice `input_ids` or `input embeds` through `cache_position`, to keep only the unprocessed tokens
|
||||
if past_key_values is not None:
|
||||
if inputs_embeds is not None:
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
if input_ids.shape[1] == 0:
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
else:
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]:
|
||||
input_ids = input_ids[:, cache_position]
|
||||
if image_attention_mask is not None:
|
||||
|
|
@ -1687,14 +1690,19 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
|
|||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
|
||||
# If past_key_values are present then slice the position ids for only the unprocessed tokens.
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0:
|
||||
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
|
||||
else:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
|
|
|
|||
|
|
@ -1901,8 +1901,7 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin):
|
|||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, "
|
||||
"for speech-to-speech.",
|
||||
"The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, for speech-to-speech.",
|
||||
MOSHI_START_DOCSTRING,
|
||||
)
|
||||
class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
|
||||
|
|
@ -2458,18 +2457,57 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
|
|||
blank_user_audio_codes: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- Moshi has custom post-processing
|
||||
# 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
# Overwritten -- Moshi has custom post-processing on the prepared inputs.
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
|
||||
|
||||
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = inputs_embeds.shape
|
||||
device = inputs_embeds.device
|
||||
else:
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_cache_shape(),
|
||||
dtype=self.lm_head.weight.dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
|
||||
# 2. Now that everything is prepared, generate audio_codes using the depth decoder
|
||||
|
|
|
|||
|
|
@ -1261,7 +1261,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
|||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and attention_mask.device.type in ["cuda", "xpu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
|
|
@ -1872,8 +1872,12 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
|||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
|
|
@ -1886,7 +1890,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
|||
pixel_values_videos = None
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
|
||||
|
|
|
|||
|
|
@ -770,8 +770,12 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
|||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
|
|
@ -784,7 +788,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
|||
pixel_values_videos = None
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
|
||||
|
|
|
|||
|
|
@ -1735,8 +1735,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
|
||||
# generate the first token for each sequence. Later use the generated Input ids for continuation.
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||
elif (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
|
|
@ -1749,7 +1753,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||
pixel_values_videos = None
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
|
||||
|
|
|
|||
|
|
@ -1557,7 +1557,7 @@ class Zamba2Model(Zamba2PreTrainedModel):
|
|||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and attention_mask.device.type in ["cuda", "xpu"]
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
|
|
|
|||
|
|
@ -1857,6 +1857,83 @@ class GenerationTesterMixin:
|
|||
)
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
"""Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call."""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if "token_type_ids" in inputs_dict:
|
||||
del inputs_dict["token_type_ids"]
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder")
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
self.skipTest(reason="This model does not support `inputs_embeds` in generation")
|
||||
|
||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||
outputs = model(**inputs_dict)
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
model.generation_config.forced_eos_token_id = None
|
||||
model.config.is_decoder = True
|
||||
model.generation_config.use_cache = True
|
||||
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"do_sample": False,
|
||||
}
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values.
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs)
|
||||
|
||||
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens)
|
||||
initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs)
|
||||
continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
|
||||
cached_output = model.generate(
|
||||
inputs_embeds=continued_embeds,
|
||||
max_new_tokens=1,
|
||||
past_key_values=initial_output.past_key_values,
|
||||
**generation_kwargs,
|
||||
)
|
||||
|
||||
# Combine the (3 + 1) generated tokens and verify it matches with full generation.
|
||||
combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
|
||||
self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
|
||||
# The two sets of past kv should be equal to each other
|
||||
for layer_idx in range(len(cached_output.past_key_values)):
|
||||
for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.past_key_values[layer_idx][kv_idx],
|
||||
cached_output.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||
@require_torch_gpu
|
||||
@pytest.mark.generate
|
||||
|
|
|
|||
|
|
@ -334,6 +334,10 @@ class ClvpDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
@unittest.skip(reason="Clvp `prepare_inputs_for_generation` function doesn't have cache position.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
|
||||
class ClvpModelForConditionalGenerationTester:
|
||||
def __init__(self, parent, is_training=False):
|
||||
|
|
|
|||
|
|
@ -131,6 +131,10 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
|
|
|
|||
|
|
@ -325,6 +325,10 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip(reason="Fuyu `prepare_inputs_for_generation` function doesn't have cache position.")
|
||||
def test_generate_continue_from_inputs_embeds():
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
|
|
|||
|
|
@ -146,6 +146,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
|
|
|
|||
|
|
@ -450,6 +450,10 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BigCodeGPT has a non-standard KV cache format and breaks this test.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_gpt_bigcode_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)
|
||||
|
|
|
|||
|
|
@ -755,6 +755,65 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
|||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@pytest.mark.generate
def test_generate_continue_from_inputs_embeds(self):
    """Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`.

    For every generative model class, the test generates 4 tokens in one `generate` call and
    compares that against generating 3 tokens, then 1 more token while passing the returned
    `past_key_values` and the extended `inputs_embeds` back in. Both the token sequences and
    the resulting key/value caches must match.
    """

    for model_class in self.all_generative_model_classes:
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        model = model_class(config).to(torch_device).eval()

        # NOTE(review): pad/eos ids are set to -1, presumably so that generation never stops
        # early on an EOS token and always produces `max_new_tokens` tokens -- confirm.
        model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
        model.generation_config.forced_eos_token_id = None
        model.generation_config.use_cache = True

        # Generation is driven by embeddings only: `input_ids` is removed from the inputs.
        input_ids = inputs.pop("input_ids")
        input_embeds = model.get_input_embeddings()(input_ids)

        generation_kwargs = {
            "return_dict_in_generate": True,
            "do_sample": False,
        }

        inputs["inputs_embeds"] = input_embeds

        # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
        outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
        # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
        # inputs may need to be tweaked across `generate` calls (like the attention mask).
        initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
        inputs["past_key_values"] = initial_output.past_key_values

        # The continued call sees the prompt embeddings followed by the embeddings of the tokens
        # generated so far.
        new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
        continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
        inputs["inputs_embeds"] = continued_embeds

        if "attention_mask" in inputs:
            # Right-pad the mask with ones so its length matches the extended embeddings.
            inputs["attention_mask"] = torch.nn.functional.pad(
                inputs["attention_mask"],
                (0, new_attention_len - inputs["attention_mask"].shape[1]),
                mode="constant",
                value=1,
            )
        if "image_attention_mask" in inputs:
            # Keep only the last entry along the second-to-last axis.
            # NOTE(review): presumably that axis is the text-token axis -- confirm.
            inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]

        cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)

        # Verify that the combined outputs match the full generation.
        combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
        self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
        for layer_idx in range(len(cached_output.past_key_values)):
            for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
                self.assertTrue(
                    torch.allclose(
                        outputs.past_key_values[layer_idx][kv_idx],
                        cached_output.past_key_values[layer_idx][kv_idx],
                    )
                )
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
|
|
|
|||
|
|
@ -358,6 +358,10 @@ class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_save_load(self):
|
||||
super().test_save_load()
|
||||
|
|
@ -824,6 +828,7 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
print(output_ids_generate)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@unittest.skip(reason="The audio encoder has no gradients.")
|
||||
|
|
@ -919,6 +924,10 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_save_load(self):
|
||||
super().test_save_load()
|
||||
|
|
|
|||
|
|
@ -333,6 +333,10 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||
"""
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Zamba2 has hybrid cache.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue