From c1dba1111b44d1bebbe3c2e42ce08ed60bf85dff Mon Sep 17 00:00:00 2001 From: Lucain Date: Fri, 28 Jul 2023 14:14:27 +0200 Subject: [PATCH] Add test when downloading from gated repo (#25039) --- tests/utils/test_hub_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/utils/test_hub_utils.py b/tests/utils/test_hub_utils.py index 540847d66..e2a3f351d 100644 --- a/tests/utils/test_hub_utils.py +++ b/tests/utils/test_hub_utils.py @@ -36,6 +36,9 @@ RANDOM_BERT = "hf-internal-testing/tiny-random-bert" CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert") FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6" +GATED_REPO = "hf-internal-testing/dummy-gated-model" +README_FILE = "README.md" + class GetFromCacheTests(unittest.TestCase): def test_cached_file(self): @@ -124,3 +127,13 @@ class GetFromCacheTests(unittest.TestCase): self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename)) self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) + + def test_get_file_gated_repo(self): + """Test download file from a gated repo fails with correct message when not authenticated.""" + with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."): + cached_file(GATED_REPO, README_FILE, use_auth_token=False) + + def test_has_file_gated_repo(self): + """Test check file existence from a gated repo fails with correct message when not authenticated.""" + with self.assertRaisesRegex(EnvironmentError, "is a gated repository"): + has_file(GATED_REPO, README_FILE, use_auth_token=False)