From ef6741fe65c130ddb33c43ad2ba2b82f40ea7e90 Mon Sep 17 00:00:00 2001 From: Leandro von Werra Date: Wed, 21 Sep 2022 11:33:22 +0400 Subject: [PATCH] Fix GLUE MNLI when using `max_eval_samples` (#18722) --- examples/pytorch/text-classification/run_glue.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index f6ed83028..3f97fc3f5 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -549,7 +549,11 @@ def main(): eval_datasets = [eval_dataset] if data_args.task_name == "mnli": tasks.append("mnli-mm") - eval_datasets.append(raw_datasets["validation_mismatched"]) + valid_mm_dataset = raw_datasets["validation_mismatched"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples) + valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples)) + eval_datasets.append(valid_mm_dataset) combined = {} for eval_dataset, task in zip(eval_datasets, tasks):