diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py index a80927d61..5c0179350 100644 --- a/src/transformers/safetensors_conversion.py +++ b/src/transformers/safetensors_conversion.py @@ -1,5 +1,3 @@ -import json -import uuid from typing import Optional import requests @@ -26,37 +24,33 @@ def spawn_conversion(token: str, private: bool, model_id: str): logger.info("Attempting to convert .bin model on the fly to safetensors.") safetensors_convert_space_url = "https://safetensors-convert.hf.space" - sse_url = f"{safetensors_convert_space_url}/queue/join" - sse_data_url = f"{safetensors_convert_space_url}/queue/data" + sse_url = f"{safetensors_convert_space_url}/call/run" - # The `fn_index` is necessary to indicate to gradio that we will use the `run` method of the Space. - hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())} - - def start(_sse_connection, payload): + def start(_sse_connection): for line in _sse_connection.iter_lines(): line = line.decode() - if line.startswith("data:"): - resp = json.loads(line[5:]) - logger.debug(f"Safetensors conversion status: {resp['msg']}") - if resp["msg"] == "queue_full": - raise ValueError("Queue is full! Please try again.") - elif resp["msg"] == "send_data": - event_id = resp["event_id"] - response = requests.post( - sse_data_url, - stream=True, - params=hash_data, - json={"event_id": event_id, **payload, **hash_data}, - ) - response.raise_for_status() - elif resp["msg"] == "process_completed": - return + if line.startswith("event:"): + status = line[7:] + logger.debug(f"Safetensors conversion status: {status}") - with requests.get(sse_url, stream=True, params=hash_data) as sse_connection: - data = {"data": [model_id, private, token]} + if status == "complete": + return + elif status == "heartbeat": + logger.debug("Heartbeat") + else: + logger.debug(f"Unknown status {status}") + else: + logger.debug(line) + + data = {"data": [model_id, private, token]} + + result = requests.post(sse_url, stream=True, json=data).json() + event_id = result["event_id"] + + with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection: try: logger.debug("Spawning safetensors automatic conversion.") - start(sse_connection, data) + start(sse_connection) except Exception as e: logger.warning(f"Error during conversion: {repr(e)}") @@ -86,7 +80,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs): try: - api = HfApi(token=cached_file_kwargs.get("token"), headers=http_user_agent()) + api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()}) sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs) if sha is None: diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 3317a47d7..8af47cde8 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2009,19 +2009,18 @@ class ModelOnTheFlyConversionTester(unittest.TestCase): if thread.name == "Thread-autoconversion": thread.join(timeout=10) - with self.subTest("PR was open with the safetensors account"): - discussions = self.api.get_repo_discussions(self.repo_name) + discussions = self.api.get_repo_discussions(self.repo_name) - bot_opened_pr = None - bot_opened_pr_title = None + bot_opened_pr = None + bot_opened_pr_title = None - for discussion in discussions: - if discussion.author == "SFconvertbot": - bot_opened_pr = True - bot_opened_pr_title = discussion.title + for discussion in discussions: + if discussion.author == "SFconvertbot": + bot_opened_pr = True + bot_opened_pr_title = discussion.title - self.assertTrue(bot_opened_pr) - self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") + self.assertTrue(bot_opened_pr) + self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") @mock.patch("transformers.safetensors_conversion.spawn_conversion") def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):