[torch_int_div] Correct true division in generation (#15498)

* [torch_int_div] Correct true division in generation

* up

* up
This commit is contained in:
Patrick von Platen 2022-02-07 16:04:18 +01:00 committed by GitHub
parent 5f1918a4a8
commit c47d259241
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 51 additions and 22 deletions

View file

@@ -1561,6 +1561,7 @@ if is_torch_available():
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
]
_import_structure["pytorch_utils"] = []
_import_structure["sagemaker"] = []
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]

View file

@@ -48,6 +48,7 @@ from .generation_stopping_criteria import (
StoppingCriteriaList,
validate_stopping_criteria,
)
from .pytorch_utils import torch_int_div
from .utils import logging
@@ -2024,7 +2025,7 @@ class GenerationMixin:
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
next_indices = (next_tokens / vocab_size).long()
next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
@@ -2345,7 +2346,7 @@ class GenerationMixin:
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
next_indices = next_tokens // vocab_size
next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
@@ -2678,7 +2679,7 @@ class GenerationMixin:
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
next_indices = next_tokens // vocab_size
next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
@@ -2706,7 +2707,7 @@ class GenerationMixin:
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
)
# Store scores, attentions and hidden_states when required

View file

@@ -23,7 +23,6 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
@@ -2463,13 +2462,3 @@ def apply_chunking_to_forward(
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)
def torch_int_div(tensor1, tensor2):
    """
    A function that performs integer division across different versions of PyTorch.
    """
    # `torch.div` only gained the `rounding_mode` argument in PyTorch 1.8.0,
    # so older releases must fall back to the `//` (floor division) operator.
    if version.parse(torch.__version__) < version.parse("1.8.0"):
        return tensor1 // tensor2
    else:
        return torch.div(tensor1, tensor2, rounding_mode="floor")

View file

@@ -33,7 +33,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_hubert import HubertConfig

View file

@@ -29,7 +29,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_sew import SEWConfig

View file

@@ -30,7 +30,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_sew_d import SEWDConfig

View file

@@ -35,7 +35,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_unispeech import UniSpeechConfig

View file

@@ -35,7 +35,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_unispeech_sat import UniSpeechSatConfig

View file

@@ -41,7 +41,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config

View file

@@ -35,7 +35,8 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel, torch_int_div
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_wavlm import WavLMConfig

View file

@@ -0,0 +1,31 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging import version
from .utils import logging
logger = logging.get_logger(__name__)
def torch_int_div(tensor1, tensor2):
    """
    A function that performs integer division across different versions of PyTorch.
    """
    if version.parse(torch.__version__) >= version.parse("1.8.0"):
        # PyTorch 1.8+ exposes an explicit floor-rounding division.
        return torch.div(tensor1, tensor2, rounding_mode="floor")
    # Older releases: the `//` operator already performs floor division.
    return tensor1 // tensor2