remote_cache: Add a waitcounter for gets and sets (#141307)

This adds a basic waitcounter to help show if we're spending a lot of
time doing gets and sets to remote caches

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141307
Approved by: https://github.com/masnesral
This commit is contained in:
Colin L. Rice 2024-12-02 12:42:08 -07:00 committed by PyTorch MergeBot
parent daa77f3d9f
commit 64d44a39a1

View file

@ -15,6 +15,7 @@ from typing_extensions import override, TypeAlias
from torch._dynamo.utils import dynamo_timed
from torch._inductor import config
from torch.monitor import _WaitCounter
try:
@ -162,29 +163,31 @@ class RemoteCache(Generic[_T]):
# See if the cache contains `key`. Returns `None` if the value is not
# present in the cache.
def get(self, key: str) -> Optional[_T]:
sample = self._create_sample()
try:
result = self._get(key, sample)
cache_stats.get(type(self).__name__, result)
except Exception:
cache_stats.exception(type(self).__name__)
raise
self._log_sample(sample)
return result
with _WaitCounter("pytorch.remote_cache.get").guard():
sample = self._create_sample()
try:
result = self._get(key, sample)
cache_stats.get(type(self).__name__, result)
except Exception:
cache_stats.exception(type(self).__name__)
raise
self._log_sample(sample)
return result
# Add `value` to the cache with the key `key`. Note that `None` is not a
# valid value even if _T supports it (because you can't tell the difference
# between `None` and a missing cache entry).
def put(self, key: str, value: _T) -> None:
assert value is not None
sample = self._create_sample()
try:
self._put(key, value, sample)
cache_stats.put(type(self).__name__)
except Exception:
cache_stats.exception(type(self).__name__)
raise
self._log_sample(sample)
with _WaitCounter("pytorch.remote_cache.put").guard():
assert value is not None
sample = self._create_sample()
try:
self._put(key, value, sample)
cache_stats.put(type(self).__name__)
except Exception:
cache_stats.exception(type(self).__name__)
raise
self._log_sample(sample)
# Used to convert data from the cache into structured data.
def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override]