mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
typing for remote_cache (#133446)
Summary: typing annotations for remote_cache Redo of #133299 with fixed annotations. Test Plan: unit tests Differential Revision: D61271883 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133446 Approved by: https://github.com/oulgen
This commit is contained in:
parent
7eb31e5023
commit
c88174df95
2 changed files with 16 additions and 10 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class RemoteCacheBackend:
|
||||
|
|
@ -12,11 +12,11 @@ class RemoteCacheBackend:
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str):
|
||||
def get(self, key: str) -> Optional[object]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, key: str, data: bytes):
|
||||
def put(self, key: str, data: bytes) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -37,8 +37,11 @@ class RedisRemoteCacheBackend(RemoteCacheBackend):
|
|||
def _get_key(self, key: str) -> str:
|
||||
return self._key_fmt.format(key=key)
|
||||
|
||||
def get(self, key: str):
|
||||
return self._redis.get(self._get_key(key))
|
||||
def get(self, key: str) -> Optional[bytes]:
|
||||
value = self._redis.get(self._get_key(key))
|
||||
# In theory redis.get() can return an Awaitable as well...
|
||||
assert value is None or isinstance(value, bytes)
|
||||
return value
|
||||
|
||||
def put(self, key: str, data: bytes):
|
||||
return self._redis.set(self._get_key(key), data)
|
||||
def put(self, key: str, data: bytes) -> None:
|
||||
self._redis.set(self._get_key(key), data)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import re
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -43,6 +43,9 @@ from .runtime_utils import (
|
|||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..remote_cache import RemoteCacheBackend
|
||||
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
|
|
@ -1079,7 +1082,7 @@ def cached_autotune(
|
|||
|
||||
local_cache = None
|
||||
cache_filename = None
|
||||
remote_cache = None
|
||||
remote_cache: Optional[RemoteCacheBackend] = None
|
||||
remote_cache_key = None
|
||||
best_config = None
|
||||
if not inductor_meta.get("force_disable_caches", False):
|
||||
|
|
@ -1146,7 +1149,7 @@ def cached_autotune(
|
|||
if local_cache is not None and cache_filename is not None:
|
||||
local_cache.put(cache_filename, data)
|
||||
if remote_cache is not None and remote_cache_key is not None:
|
||||
remote_cache.put(remote_cache_key, data)
|
||||
remote_cache.put(remote_cache_key, data) # type: ignore[arg-type]
|
||||
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
type_str = "coordesc" if found_by_coordesc else "heuristic"
|
||||
|
|
|
|||
Loading…
Reference in a new issue