Mirror of https://github.com/saymrwulf/transformers.git
Update namespaces inside torch.utils.data to the latest. (#13167)
* Update torch.utils.data namespaces to the latest.
* Format
* Update Dataloader.
* Style
This commit is contained in:
parent 1fec32adc6
commit 91ff480e26
24 changed files with 41 additions and 44 deletions
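For context, both import paths resolve to the same classes; the commit simply standardizes on the names re-exported at the torch.utils.data package root instead of the internal submodules. A minimal sketch of the equivalence (illustration only, not part of the diff; the alias name is invented):

# Old style: reaching into the submodule.
from torch.utils.data.dataloader import DataLoader as SubmoduleDataLoader
# New style: the documented package-root namespace.
from torch.utils.data import DataLoader

# Same object either way; only the import path changes.
assert SubmoduleDataLoader is DataLoader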
@@ -77,7 +77,7 @@ class Split(Enum):
 
 if is_torch_available():
     import torch
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset
 
     class MultipleChoiceDataset(Dataset):
         """
@@ -141,7 +141,7 @@ class Seq2SeqTrainer(Trainer):
             )
         return scheduler
 
-    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
         elif is_torch_tpu_available():
@@ -206,7 +206,7 @@ class TokenClassificationTask:
 if is_torch_available():
     import torch
     from torch import nn
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset
 
     class TokenClassificationDataset(Dataset):
         """
@@ -31,7 +31,7 @@ import random
 import datasets
 import torch
 from datasets import load_dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -31,7 +31,7 @@ import random
 import datasets
 import torch
 from datasets import load_dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -29,7 +29,7 @@ from typing import Optional, Union
 import datasets
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -28,7 +28,7 @@ import datasets
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -28,7 +28,7 @@ import datasets
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -29,7 +29,7 @@ import nltk
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -21,7 +21,7 @@ import random
 
 import datasets
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -27,7 +27,7 @@ import random
 import datasets
 import torch
 from datasets import ClassLabel, load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -28,7 +28,7 @@ import datasets
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -88,7 +88,7 @@ class InputFeatures:
 
 if is_torch_available():
     import torch
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset
 
     class HansDataset(Dataset):
         """
@@ -19,7 +19,7 @@ import copy
 from collections import defaultdict
 
 import numpy as np
-from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.data import BatchSampler, Sampler
 
 from utils import logger
 
@@ -20,7 +20,7 @@ from enum import Enum
 from typing import List, Optional, Union
 
 import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
 
 from filelock import FileLock
 
@@ -21,7 +21,7 @@ import warnings
 from typing import Dict, List, Optional
 
 import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
 
 from filelock import FileLock
 
@@ -19,7 +19,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
 
 from filelock import FileLock
 
@@ -49,10 +49,8 @@ import numpy as np
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from . import __version__
 from .configuration_utils import PretrainedConfig
@@ -206,16 +204,16 @@ class Trainer:
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
             Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
             :func:`~transformers.DataCollatorWithPadding` otherwise.
-        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
+        train_dataset (:obj:`torch.utils.data.Dataset` or :obj:`torch.utils.data.IterableDataset`, `optional`):
             The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
 
-            Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
-            training in a distributed fashion, your iterable dataset should either use a internal attribute
-            :obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
-            processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
-            :obj:`set_epoch()` method that internally sets the seed of the RNGs used.
-        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            Note that if it's a :obj:`torch.utils.data.IterableDataset` with some randomization and you are training in
+            a distributed fashion, your iterable dataset should either use a internal attribute :obj:`generator` that
+            is a :obj:`torch.Generator` for the randomization that must be identical on all processes (and the Trainer
+            will manually set the seed of this :obj:`generator` at each epoch) or have a :obj:`set_epoch()` method that
+            internally sets the seed of the RNGs used.
+        eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
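As a hedged illustration of the contract that docstring describes, here is a hypothetical iterable dataset exposing both hooks; the class name and data layout are invented, only the generator attribute and set_epoch() method follow the documented contract:

import torch
from torch.utils.data import IterableDataset

class ShuffledStream(IterableDataset):
    def __init__(self, data):
        self.data = list(data)
        # The Trainer may re-seed this generator each epoch so that every
        # distributed process draws the identical permutation.
        self.generator = torch.Generator()

    def set_epoch(self, epoch: int):
        # Alternative hook: derive all internal RNG state from the epoch.
        self.generator.manual_seed(epoch)

    def __iter__(self):
        perm = torch.randperm(len(self.data), generator=self.generator)
        for i in perm.tolist():
            yield self.data[i]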
@@ -537,7 +535,7 @@ class Trainer:
         else:
             return dataset.remove_columns(ignored_columns)
 
-    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
 
@@ -617,7 +615,7 @@ class Trainer:
         if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
 
-        if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset):
+        if isinstance(train_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
@@ -647,7 +645,7 @@ class Trainer:
             pin_memory=self.args.dataloader_pin_memory,
         )
 
-    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         # Deprecated code
         if self.args.use_legacy_prediction_loop:
             if is_torch_tpu_available():
@@ -683,7 +681,7 @@ class Trainer:
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
-            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
                 If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                 accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
@@ -694,7 +692,7 @@ class Trainer:
         if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
 
-        if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
                 eval_dataset = IterableDatasetShard(
                     eval_dataset,
@@ -730,14 +728,14 @@ class Trainer:
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
-            test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            test_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
                 The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                 ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
         if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
 
-        if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
+        if isinstance(test_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
                 test_dataset = IterableDatasetShard(
                     test_dataset,
@@ -175,9 +175,9 @@ class TrainerCallback:
             The optimizer used for the training steps.
         lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
             The scheduler used for setting the learning rate.
-        train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+        train_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
             The current dataloader used for training.
-        eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+        eval_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
             The current dataloader used for training.
         metrics (:obj:`Dict[str, float]`):
             The metrics computed by the last evaluation phase.
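For context, a minimal sketch of a callback consuming the arguments documented above; the class name and print format are invented, while the on_evaluate hook and its keyword arguments follow the TrainerCallback API:

from transformers import TrainerCallback

class EvalLoggerCallback(TrainerCallback):
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # `metrics` holds the results of the evaluation phase that just ran;
        # the dataloaders documented above arrive through **kwargs.
        eval_dataloader = kwargs.get("eval_dataloader")
        if metrics is not None:
            print(f"step {state.global_step}: {metrics}")

It would be registered with trainer.add_callback(EvalLoggerCallback()).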
@@ -29,9 +29,8 @@ import numpy as np
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, Sampler
 
 from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
 from .tokenization_utils_base import BatchEncoding
@@ -290,7 +289,7 @@ class SequentialDistributedSampler(Sampler):
         return self.num_samples
 
 
-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
+def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -690,7 +689,7 @@ class IterableDatasetShard(IterableDataset):
 
 
     Args:
-        dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
+        dataset (:obj:`torch.utils.data.IterableDataset`):
            The batch sampler to split in several shards.
        batch_size (:obj:`int`, `optional`, defaults to 1):
            The size of the batches per shard.
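A hedged usage sketch for the sharding behavior documented above; the toy RangeStream dataset is invented, and the IterableDatasetShard import path reflects transformers.trainer_pt_utils around the time of this commit:

from torch.utils.data import IterableDataset
from transformers.trainer_pt_utils import IterableDatasetShard

class RangeStream(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

# Simulate two processes: each shard takes its batch_size slice of every
# "real batch" of batch_size * num_processes consecutive examples.
shard0 = IterableDatasetShard(RangeStream(16), batch_size=4, num_processes=2, process_index=0)
shard1 = IterableDatasetShard(RangeStream(16), batch_size=4, num_processes=2, process_index=1)
print(list(shard0))  # expected: [0, 1, 2, 3, 8, 9, 10, 11]
print(list(shard1))  # expected: [4, 5, 6, 7, 12, 13, 14, 15]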
@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
 
 from .deepspeed import is_deepspeed_zero3_enabled
 from .trainer import Trainer
@@ -499,7 +499,7 @@ import random
 
 import datasets
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
@@ -31,7 +31,7 @@ logger = logging.get_logger(__name__)
 if is_torch_available():
     import torch
     from torch import nn
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset
 
     from transformers import Trainer
 
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
 if is_torch_available():
     import torch
     from torch import nn
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset
 
     from transformers import Trainer
 