From a0a3e2f4692e7eef31ab6303def799d8bdf8906e Mon Sep 17 00:00:00 2001 From: Noah Young Date: Wed, 10 Jul 2024 02:17:03 -0700 Subject: [PATCH] Fix file type checks in data splits for contrastive training example script (#31720) fix data split file type checks --- examples/pytorch/contrastive-image-text/run_clip.py | 6 +++--- examples/tensorflow/contrastive-image-text/run_clip.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/contrastive-image-text/run_clip.py b/examples/pytorch/contrastive-image-text/run_clip.py index fed4d0bf6..c4936410c 100644 --- a/examples/pytorch/contrastive-image-text/run_clip.py +++ b/examples/pytorch/contrastive-image-text/run_clip.py @@ -190,9 +190,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension == "json", "`validation_file` should be a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." dataset_name_mapping = { diff --git a/examples/tensorflow/contrastive-image-text/run_clip.py b/examples/tensorflow/contrastive-image-text/run_clip.py index 0644ab25b..ba83bbe56 100644 --- a/examples/tensorflow/contrastive-image-text/run_clip.py +++ b/examples/tensorflow/contrastive-image-text/run_clip.py @@ -196,9 +196,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension == "json", "`validation_file` should be a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." dataset_name_mapping = {