diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 144c04e0f..aab8022a8 100755 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -242,9 +242,12 @@ def main(): data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] + extension = data_args.validation_file.split(".")[-1] + datasets = load_dataset(extension, data_files=data_files, field="data") # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html.