mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix synced multi-GPU generation with LLMs and VLMs (#35893)
* Fix synced multi-GPU generation

* fix copies

---------

Co-authored-by: Davit Manukyan <ManukyanD>
Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
parent 4831a94ee7
commit d8080d55c7
13 changed files with 103 additions and 17 deletions
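
Every file below gets the same two changes: `is_torchdynamo_compiling` is added to the imports, and the input-slicing block in `prepare_inputs_for_generation` gains an "Exception 3" branch. Before this commit, a GPU that had already finished generating in a synced group could be handed a cache_position pointing past the end of its own input_ids, and the default branch `input_ids[:, cache_position]` would index out of bounds. A minimal standalone sketch of the new logic (the function name and the plain `compiling` flag are illustrative, not part of the library's API):

import torch

def slice_unprocessed_tokens(input_ids, inputs_embeds, cache_position, compiling=False):
    # Exception 1: with inputs_embeds, input_ids may be missing entries -> keep the tail.
    # Exception 3: with synced GPUs, cache_position can run past input_ids;
    #              fall back to the dummy tail instead of indexing out of bounds.
    if inputs_embeds is not None or (compiling or cache_position[-1] >= input_ids.shape[1]):
        return input_ids[:, -cache_position.shape[0] :]
    # Default case; Exception 2 (a no-op) is the implicit final "else".
    elif input_ids.shape[1] != cache_position.shape[0]:
        return input_ids[:, cache_position]
    return input_ids

ids = torch.arange(10).unsqueeze(0)  # (1, 10): prompt plus already-generated tokens
pos = torch.tensor([12])             # dummy step on a finished, synced GPU
print(slice_unprocessed_tokens(ids, None, pos))  # tensor([[9]]) -- last token only, no IndexError

The `is_torchdynamo_compiling()` term is there because the out-of-bounds comparison is data dependent and cannot be checked while tracing; under torch.compile the code conservatively takes the shape-stable tail slice.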
src/transformers/models/bamba/modeling_bamba.py

@@ -40,7 +40,13 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
+    logging,
+    replace_return_docstrings,
+)
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.import_utils import (
     is_causal_conv1d_available,

@@ -1578,8 +1584,13 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if not empty_past_kv:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/bamba/modular_bamba.py

@@ -51,6 +51,7 @@ from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -1268,8 +1269,13 @@ class BambaForCausalLM(LlamaForCausalLM):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if not empty_past_kv:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/bloom/modeling_bloom.py

@@ -36,7 +36,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import logging
+from ...utils import is_torchdynamo_compiling, logging
 from .configuration_bloom import BloomConfig
 
 

@@ -893,8 +893,13 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/chameleon/modeling_chameleon.py

@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -1651,8 +1652,13 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/cohere2/modeling_cohere2.py

@@ -36,6 +36,7 @@ from ...utils import (
     LossKwargs,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -921,8 +922,13 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/cohere2/modular_cohere2.py

@@ -29,6 +29,7 @@ from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import (
+    is_torchdynamo_compiling,
     logging,
 )
 from ..cohere.modeling_cohere import (

@@ -590,8 +591,13 @@ class Cohere2ForCausalLM(CohereForCausalLM):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/jamba/modeling_jamba.py

@@ -42,6 +42,7 @@ from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -1566,8 +1567,13 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if not empty_past_kv:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/mllama/modeling_mllama.py

@@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -2174,8 +2175,13 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

@@ -45,6 +45,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -1869,8 +1870,13 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py

@@ -51,7 +51,7 @@ from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, VideoInput
 from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import is_flash_attn_2_available
+from ...utils import is_flash_attn_2_available, is_torchdynamo_compiling
 
 
 if is_flash_attn_2_available():

@@ -768,8 +768,13 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

@@ -41,6 +41,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -1732,8 +1733,13 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
src/transformers/models/zamba/modeling_zamba.py

@@ -45,6 +45,7 @@ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )

@@ -1321,7 +1322,12 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if inputs_embeds is not None:  # Exception 1
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
+        if (
+            inputs_embeds is not None  # Exception 1
+            or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+        ):
             input_ids = input_ids[:, -cache_position.shape[0] :]
         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
             input_ids = input_ids[:, cache_position]
src/transformers/models/zamba2/modeling_zamba2.py

@@ -37,7 +37,13 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast,
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
+    logging,
+    replace_return_docstrings,
+)
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
 from .configuration_zamba2 import Zamba2Config

@@ -1753,7 +1759,12 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if inputs_embeds is not None:  # Exception 1
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
+        if (
+            inputs_embeds is not None  # Exception 1
+            or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+        ):
             input_ids = input_ids[:, -cache_position.shape[0] :]
         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
             input_ids = input_ids[:, cache_position]
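
For context, a hedged sketch of the setting that exercises the new branch. The model name and prompts are illustrative assumptions, and the script is meant to run under a distributed launcher (e.g. `torchrun --nproc_per_node=2 script.py`) with a working NCCL setup:

import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

dist.init_process_group("nccl")
rank = dist.get_rank()

tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # any of the models patched above
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(f"cuda:{rank}")

# With synced_gpus=True every rank keeps stepping until all ranks are done, so a
# rank whose sequence hits EOS early performs dummy forward passes -- the
# Exception 3 case that used to index out of bounds.
prompt = "Hi" if rank == 0 else "Write a long story about a ship"
inputs = tok(prompt, return_tensors="pt").to(f"cuda:{rank}")
out = model.generate(**inputs, max_new_tokens=32, synced_gpus=True)
print(rank, tok.decode(out[0], skip_special_tokens=True))

dist.destroy_process_group()

Note that `synced_gpus` is an existing `generate()` flag (and is enabled automatically under DeepSpeed ZeRO-3), not something introduced here; this commit only makes each model's `prepare_inputs_for_generation` tolerate the dummy steps that flag produces.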