From c88174df9570ef0c2829b29f50609b44fa2cda06 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Thu, 15 Aug 2024 06:36:13 +0000 Subject: [PATCH] typing for remote_cache (#133446) Summary: typing annotations for remote_cache Redo of #133299 with fixed annotations. Test Plan: unit tests Differential Revision: D61271883 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133446 Approved by: https://github.com/oulgen --- torch/_inductor/remote_cache.py | 17 ++++++++++------- torch/_inductor/runtime/triton_heuristics.py | 9 ++++++--- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 5f4c934edcf..391192eaac6 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -1,6 +1,6 @@ -# mypy: allow-untyped-defs import os from abc import abstractmethod +from typing import Optional class RemoteCacheBackend: @@ -12,11 +12,11 @@ class RemoteCacheBackend: pass @abstractmethod - def get(self, key: str): + def get(self, key: str) -> Optional[object]: pass @abstractmethod - def put(self, key: str, data: bytes): + def put(self, key: str, data: bytes) -> None: pass @@ -37,8 +37,11 @@ class RedisRemoteCacheBackend(RemoteCacheBackend): def _get_key(self, key: str) -> str: return self._key_fmt.format(key=key) - def get(self, key: str): - return self._redis.get(self._get_key(key)) + def get(self, key: str) -> Optional[bytes]: + value = self._redis.get(self._get_key(key)) + # In theory redis.get() can return an Awaitable as well... + assert value is None or isinstance(value, bytes) + return value - def put(self, key: str, data: bytes): - return self._redis.set(self._get_key(key), data) + def put(self, key: str, data: bytes) -> None: + self._redis.set(self._get_key(key), data) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c439260de5b..253803fb247 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -14,7 +14,7 @@ import re import sys import threading import time -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING import torch @@ -43,6 +43,9 @@ from .runtime_utils import ( ) +if TYPE_CHECKING: + from ..remote_cache import RemoteCacheBackend + try: import triton except ImportError: @@ -1079,7 +1082,7 @@ def cached_autotune( local_cache = None cache_filename = None - remote_cache = None + remote_cache: Optional[RemoteCacheBackend] = None remote_cache_key = None best_config = None if not inductor_meta.get("force_disable_caches", False): @@ -1146,7 +1149,7 @@ def cached_autotune( if local_cache is not None and cache_filename is not None: local_cache.put(cache_filename, data) if remote_cache is not None and remote_cache_key is not None: - remote_cache.put(remote_cache_key, data) + remote_cache.put(remote_cache_key, data) # type: ignore[arg-type] if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic"