diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index f6ed83028..3f97fc3f5 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -549,7 +549,11 @@ def main(): eval_datasets = [eval_dataset] if data_args.task_name == "mnli": tasks.append("mnli-mm") - eval_datasets.append(raw_datasets["validation_mismatched"]) + valid_mm_dataset = raw_datasets["validation_mismatched"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples) + valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples)) + eval_datasets.append(valid_mm_dataset) combined = {} for eval_dataset, task in zip(eval_datasets, tasks):