typing for remote_cache (#133446)

Summary:
typing annotations for remote_cache
Redo of #133299 with fixed annotations.

Test Plan: unit tests

Differential Revision: D61271883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133446
Approved by: https://github.com/oulgen
This commit is contained in:
Aaron Orenstein 2024-08-15 06:36:13 +00:00 committed by PyTorch MergeBot
parent 7eb31e5023
commit c88174df95
2 changed files with 16 additions and 10 deletions

View file

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import os
from abc import abstractmethod
from typing import Optional
class RemoteCacheBackend:
@ -12,11 +12,11 @@ class RemoteCacheBackend:
pass
@abstractmethod
def get(self, key: str):
def get(self, key: str) -> Optional[object]:
pass
@abstractmethod
def put(self, key: str, data: bytes):
def put(self, key: str, data: bytes) -> None:
pass
@ -37,8 +37,11 @@ class RedisRemoteCacheBackend(RemoteCacheBackend):
def _get_key(self, key: str) -> str:
return self._key_fmt.format(key=key)
def get(self, key: str):
return self._redis.get(self._get_key(key))
def get(self, key: str) -> Optional[bytes]:
value = self._redis.get(self._get_key(key))
# In theory redis.get() can return an Awaitable as well...
assert value is None or isinstance(value, bytes)
return value
def put(self, key: str, data: bytes):
return self._redis.set(self._get_key(key), data)
def put(self, key: str, data: bytes) -> None:
self._redis.set(self._get_key(key), data)

View file

@ -14,7 +14,7 @@ import re
import sys
import threading
import time
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import torch
@ -43,6 +43,9 @@ from .runtime_utils import (
)
if TYPE_CHECKING:
from ..remote_cache import RemoteCacheBackend
try:
import triton
except ImportError:
@ -1079,7 +1082,7 @@ def cached_autotune(
local_cache = None
cache_filename = None
remote_cache = None
remote_cache: Optional[RemoteCacheBackend] = None
remote_cache_key = None
best_config = None
if not inductor_meta.get("force_disable_caches", False):
@ -1146,7 +1149,7 @@ def cached_autotune(
if local_cache is not None and cache_filename is not None:
local_cache.put(cache_filename, data)
if remote_cache is not None and remote_cache_key is not None:
remote_cache.put(remote_cache_key, data)
remote_cache.put(remote_cache_key, data) # type: ignore[arg-type]
if log.isEnabledFor(logging.DEBUG):
type_str = "coordesc" if found_by_coordesc else "heuristic"