diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index 15e99bdb0..c13fffd17 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -3,7 +3,6 @@ import glob
 import logging
 import os
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
     default_val_metric = "rouge2"
 
     def __init__(self, hparams, **kwargs):
+        if hparams.sortish_sampler and hparams.gpus > 1:
+            hparams.replace_sampler_ddp = False
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
         use_task_specific_params(self.model, "summarization")
         save_git_info(self.hparams.output_dir)
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
-        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
-            self.hparams.sortish_sampler = False
-            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
 
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
-            sampler = dataset.make_sortish_sampler(batch_size)
+            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             shuffle = False
 
         dataloader = DataLoader(
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 7acbbd7b5..26fb20bde 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -149,9 +149,9 @@ class TestSummarizationDistiller(unittest.TestCase):
             no_teacher=True,
             freeze_encoder=True,
             gpus=2,
-            sortish_sampler=False,
+            sortish_sampler=True,
         )
-        self._test_distiller_cli(updates)
+        self._test_distiller_cli(updates, check_contents=False)
 
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 604cc6907..25d64f727 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -1,6 +1,7 @@
 import itertools
 import json
 import linecache
+import math
 import os
 import pickle
 from logging import getLogger
@@ -10,6 +11,7 @@ from typing import Callable, Dict, Iterable, List, Union
 import git
 import numpy as np
 import torch
+import torch.distributed as dist
 from rouge_score import rouge_scorer, scoring
 from sacrebleu import corpus_bleu
 from torch import nn
@@ -111,8 +113,11 @@ class AbstractSeq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]
 
-    def make_sortish_sampler(self, batch_size):
-        return SortishSampler(self.src_lens, batch_size)
+    def make_sortish_sampler(self, batch_size, distributed=False):
+        if distributed:
+            return DistributedSortishSampler(self, batch_size)
+        else:
+            return SortishSampler(self.src_lens, batch_size)
 
     def __getitem__(self, item):
         raise NotImplementedError("You must implement this")
@@ -191,24 +196,77 @@ class SortishSampler(Sampler):
     def __init__(self, data, batch_size):
         self.data, self.bs = data, batch_size
 
-    def key(self, i):
-        return self.data[i]
-
     def __len__(self) -> int:
         return len(self.data)
 
     def __iter__(self):
-        idxs = np.random.permutation(len(self.data))
-        sz = self.bs * 50
-        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
-        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
-        sz = self.bs
-        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
-        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
-        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
-        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
-        sort_idx = np.concatenate((ck_idx[0], sort_idx))
-        return iter(sort_idx)
+        return iter(sortish_sampler_indices(self.data, self.bs))
+
+
+def sortish_sampler_indices(data: List, bs: int) -> np.array:
+    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
+
+    def key_fn(i):
+        return data[i]
+
+    idxs = np.random.permutation(len(data))
+    sz = bs * 50
+    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
+    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
+    sz = bs
+    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
+    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
+    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
+    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
+    sort_idx = np.concatenate((ck_idx[0], sort_idx))
+    return sort_idx
+
+
+class DistributedSortishSampler(Sampler):
+    """Copied from torch DistributedSampler"""
+
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size
+
+    def __iter__(self) -> Iterable:
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        available_indices = self.get_indices_for_rank()  # indices[self.rank: self.total_size: self.num_replicas]
+
+        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
+        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
+        indices = [available_indices[i] for i in sortish_indices]
+        assert len(indices) == self.num_samples
+        return iter(indices)
+
+    def get_indices_for_rank(self) -> np.array:
+        indices = list(range(len(self.dataset)))
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+        # subsample
+        available_indices = indices[self.rank : self.total_size : self.num_replicas]
+        return available_indices
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
 
 
 logger = getLogger(__name__)
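For reference, a minimal usage sketch of the new distributed path outside the Lightning trainer. It is illustrative only and not part of the diff: `dataset` and `num_epochs` are placeholders, `dataset` stands in for any concrete AbstractSeq2SeqDataset subclass from examples/seq2seq/utils.py, and torch.distributed is assumed to be initialised with one process per GPU. The DataLoader call mirrors the one in SummarizationModule.get_dataloader.

from torch.utils.data import DataLoader

# `dataset` is a placeholder for a concrete AbstractSeq2SeqDataset subclass instance
sampler = dataset.make_sortish_sampler(batch_size=8, distributed=True)  # returns a DistributedSortishSampler
loader = DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=False, collate_fn=dataset.collate_fn)

num_epochs = 3  # placeholder
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # mirrors torch DistributedSampler's API
    for batch in loader:
        ...  # each rank iterates over its own roughly length-sorted shard of the data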