From 008672e6e5fb0f2d2fc6fbd367ab6e135eea3f2d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 18 Mar 2021 13:12:04 -0400 Subject: [PATCH] Fix distributed evaluation (#10795) * Fix distributed evaluation * Use logger --- src/transformers/trainer.py | 11 ++++++++--- tests/test_trainer_distributed.py | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a809cb7fa..14aefba18 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -690,7 +690,7 @@ class Trainer: """ Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset. - Will raise an exception if the underlying dataset dese not implement method :obj:`__len__` + Will raise an exception if the underlying dataset does not implement method :obj:`__len__` """ return len(dataloader.dataset) @@ -1812,8 +1812,13 @@ class Trainer: eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) if not prediction_loss_only: - preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) - labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + # The actual number of eval_samples can be greater than num_examples in distributed settings (when we pass + # a batch size to the sampler) + make_multiple_of = None + if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): + make_multiple_of = dataloader.sampler.batch_size + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) model.eval() diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py index c0fbd3731..d6783a628 100644 --- 
a/tests/test_trainer_distributed.py +++ b/tests/test_trainer_distributed.py @@ -97,6 +97,11 @@ if __name__ == "__main__": def compute_metrics(p: EvalPrediction) -> Dict: sequential = list(range(len(dataset))) success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential + if not success and training_args.local_rank == 0: + logger.warning( + "Predictions and/or labels do not match expected results:\n - predictions: " + f"{p.predictions.tolist()}\n - labels: {p.label_ids.tolist()}\n - expected: {sequential}" + ) return {"success": success} trainer = Trainer(