diff --git a/examples/pytorch/contrastive-image-text/run_clip.py b/examples/pytorch/contrastive-image-text/run_clip.py index fed4d0bf6..c4936410c 100644 --- a/examples/pytorch/contrastive-image-text/run_clip.py +++ b/examples/pytorch/contrastive-image-text/run_clip.py @@ -190,9 +190,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension == "json", "`validation_file` should be a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." dataset_name_mapping = { diff --git a/examples/tensorflow/contrastive-image-text/run_clip.py b/examples/tensorflow/contrastive-image-text/run_clip.py index 0644ab25b..ba83bbe56 100644 --- a/examples/tensorflow/contrastive-image-text/run_clip.py +++ b/examples/tensorflow/contrastive-image-text/run_clip.py @@ -196,9 +196,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension == "json", "`validation_file` should be a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." dataset_name_mapping = {