Update namespaces inside torch.utils.data to the latest. (#13167)

* Update torch.utils.data namespaces to the latest.

* Format

* Update Dataloader.

* Style
This commit is contained in:
Allan Lin 2021-08-19 20:29:51 +08:00 committed by GitHub
parent 1fec32adc6
commit 91ff480e26
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 41 additions and 44 deletions

View file

@@ -77,7 +77,7 @@ class Split(Enum):
if is_torch_available():
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
class MultipleChoiceDataset(Dataset):
"""

View file

@@ -141,7 +141,7 @@ class Seq2SeqTrainer(Trainer):
)
return scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():

View file

@@ -206,7 +206,7 @@ class TokenClassificationTask:
if is_torch_available():
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
class TokenClassificationDataset(Dataset):
"""

View file

@@ -31,7 +31,7 @@ import random
import datasets
import torch
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -31,7 +31,7 @@ import random
import datasets
import torch
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -29,7 +29,7 @@ from typing import Optional, Union
import datasets
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -28,7 +28,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -28,7 +28,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -29,7 +29,7 @@ import nltk
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -21,7 +21,7 @@ import random
import datasets
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -27,7 +27,7 @@ import random
import datasets
import torch
from datasets import ClassLabel, load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -28,7 +28,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -88,7 +88,7 @@ class InputFeatures:
if is_torch_available():
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
class HansDataset(Dataset):
"""

View file

@@ -19,7 +19,7 @@ import copy
from collections import defaultdict
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.data import BatchSampler, Sampler
from utils import logger

View file

@@ -20,7 +20,7 @@ from enum import Enum
from typing import List, Optional, Union
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
from filelock import FileLock

View file

@@ -21,7 +21,7 @@ import warnings
from typing import Dict, List, Optional
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
from filelock import FileLock

View file

@@ -19,7 +19,7 @@ from enum import Enum
from typing import Dict, List, Optional, Union
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
from filelock import FileLock

View file

@@ -49,10 +49,8 @@ import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PretrainedConfig
@@ -206,16 +204,16 @@ class Trainer:
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
:func:`~transformers.DataCollatorWithPadding` otherwise.
train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
train_dataset (:obj:`torch.utils.data.Dataset` or :obj:`torch.utils.data.IterableDataset`, `optional`):
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
training in a distributed fashion, your iterable dataset should either use a internal attribute
:obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
:obj:`set_epoch()` method that internally sets the seed of the RNGs used.
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
Note that if it's a :obj:`torch.utils.data.IterableDataset` with some randomization and you are training in
a distributed fashion, your iterable dataset should either use a internal attribute :obj:`generator` that
is a :obj:`torch.Generator` for the randomization that must be identical on all processes (and the Trainer
will manually set the seed of this :obj:`generator` at each epoch) or have a :obj:`set_epoch()` method that
internally sets the seed of the RNGs used.
eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
@@ -537,7 +535,7 @@ class Trainer:
else:
return dataset.remove_columns(ignored_columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
return None
@@ -617,7 +615,7 @@ class Trainer:
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset):
if isinstance(train_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
train_dataset,
@@ -647,7 +645,7 @@ class Trainer:
pin_memory=self.args.dataloader_pin_memory,
)
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
# Deprecated code
if self.args.use_legacy_prediction_loop:
if is_torch_tpu_available():
@@ -683,7 +681,7 @@ class Trainer:
Subclass and override this method if you want to inject some custom behavior.
Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
@@ -694,7 +692,7 @@ class Trainer:
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
eval_dataset = IterableDatasetShard(
eval_dataset,
@@ -730,14 +728,14 @@ class Trainer:
Subclass and override this method if you want to inject some custom behavior.
Args:
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
test_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
if isinstance(test_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
test_dataset = IterableDatasetShard(
test_dataset,

View file

@@ -175,9 +175,9 @@ class TrainerCallback:
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
The scheduler used for setting the learning rate.
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
train_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
The current dataloader used for training.
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
eval_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
The current dataloader used for training.
metrics (:obj:`Dict[str, float]`):
The metrics computed by the last evaluation phase.

View file

@@ -29,9 +29,8 @@ import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
from .tokenization_utils_base import BatchEncoding
@@ -290,7 +289,7 @@ class SequentialDistributedSampler(Sampler):
return self.num_samples
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -690,7 +689,7 @@ class IterableDatasetShard(IterableDataset):
Args:
dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
dataset (:obj:`torch.utils.data.IterableDataset`):
The batch sampler to split in several shards.
batch_size (:obj:`int`, `optional`, defaults to 1):
The size of the batches per shard.

View file

@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
from .deepspeed import is_deepspeed_zero3_enabled
from .trainer import Trainer

View file

@@ -499,7 +499,7 @@ import random
import datasets
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers

View file

@@ -31,7 +31,7 @@ logger = logging.get_logger(__name__)
if is_torch_available():
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
from transformers import Trainer

View file

@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
if is_torch_available():
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset
from transformers import Trainer