mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[torch_int_div] Correct true division in generation (#15498)
* [torch_int_div] Correct true division in generation * up * up
This commit is contained in:
parent
5f1918a4a8
commit
c47d259241
11 changed files with 51 additions and 22 deletions
|
|
@ -1561,6 +1561,7 @@ if is_torch_available():
|
|||
"get_polynomial_decay_schedule_with_warmup",
|
||||
"get_scheduler",
|
||||
]
|
||||
_import_structure["pytorch_utils"] = []
|
||||
_import_structure["sagemaker"] = []
|
||||
_import_structure["trainer"] = ["Trainer"]
|
||||
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ from .generation_stopping_criteria import (
|
|||
StoppingCriteriaList,
|
||||
validate_stopping_criteria,
|
||||
)
|
||||
from .pytorch_utils import torch_int_div
|
||||
from .utils import logging
|
||||
|
||||
|
||||
|
|
@ -2024,7 +2025,7 @@ class GenerationMixin:
|
|||
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = (next_tokens / vocab_size).long()
|
||||
next_indices = torch_int_div(next_tokens, vocab_size)
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
|
|
@ -2345,7 +2346,7 @@ class GenerationMixin:
|
|||
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
|
||||
next_tokens = torch.gather(next_tokens, -1, _indices)
|
||||
|
||||
next_indices = next_tokens // vocab_size
|
||||
next_indices = torch_int_div(next_tokens, vocab_size)
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
|
|
@ -2678,7 +2679,7 @@ class GenerationMixin:
|
|||
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = next_tokens // vocab_size
|
||||
next_indices = torch_int_div(next_tokens, vocab_size)
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
|
|
@ -2706,7 +2707,7 @@ class GenerationMixin:
|
|||
# (beam_idx // group_size) -> batch_idx
|
||||
# (beam_idx % group_size) -> offset of idx inside the group
|
||||
reordering_indices[batch_group_indices] = (
|
||||
num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
|
||||
num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
|
||||
)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ from functools import partial
|
|||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import Tensor, device, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
|
|
@ -2463,13 +2462,3 @@ def apply_chunking_to_forward(
|
|||
return torch.cat(output_chunks, dim=chunk_dim)
|
||||
|
||||
return forward_fn(*input_tensors)
|
||||
|
||||
|
||||
def torch_int_div(tensor1, tensor2):
    """
    Floor-divide ``tensor1`` by ``tensor2`` in a way that works across PyTorch versions.

    ``torch.div`` only gained the ``rounding_mode`` argument in PyTorch 1.8.0, while plain
    ``//`` on tensors emits a deprecation warning on newer releases — so the implementation
    is picked based on the installed version. Both branches perform floor division, so the
    result is identical either way.
    """
    # `rounding_mode="floor"` is unavailable before 1.8.0; fall back to the `//` operator there.
    needs_legacy_operator = version.parse(torch.__version__) < version.parse("1.8.0")
    if needs_legacy_operator:
        return tensor1 // tensor2
    return torch.div(tensor1, tensor2, rounding_mode="floor")
|
||||
|
|
|
|||
|
|
@ -33,7 +33,8 @@ from ...file_utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_hubert import HubertConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_sew import SEWConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_sew_d import SEWDConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ from ...file_utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_unispeech import UniSpeechConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ from ...file_utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_unispeech_sat import UniSpeechSatConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ from ...modeling_outputs import (
|
|||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ from ...file_utils import (
|
|||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, torch_int_div
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import torch_int_div
|
||||
from ...utils import logging
|
||||
from .configuration_wavlm import WavLMConfig
|
||||
|
||||
|
|
|
|||
31
src/transformers/pytorch_utils.py
Normal file
31
src/transformers/pytorch_utils.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def torch_int_div(tensor1, tensor2):
    """
    Floor-divide ``tensor1`` by ``tensor2`` in a way that works across PyTorch versions.

    ``torch.div`` only gained the ``rounding_mode`` argument in PyTorch 1.8.0, while plain
    ``//`` on tensors emits a deprecation warning on newer releases — so the implementation
    is picked based on the installed version. Both branches perform floor division, so the
    result is identical either way.
    """
    # `rounding_mode="floor"` is unavailable before 1.8.0; fall back to the `//` operator there.
    needs_legacy_operator = version.parse(torch.__version__) < version.parse("1.8.0")
    if needs_legacy_operator:
        return tensor1 // tensor2
    return torch.div(tensor1, tensor2, rounding_mode="floor")
|
||||
Loading…
Reference in a new issue