mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
remote_cache: Add a waitcounter for gets and sets (#141307)
This adds a basic waitcounter to help show if we're spending a lot of time doing gets and sets to remote caches Pull Request resolved: https://github.com/pytorch/pytorch/pull/141307 Approved by: https://github.com/masnesral
This commit is contained in:
parent
daa77f3d9f
commit
64d44a39a1
1 changed files with 21 additions and 18 deletions
|
|
@ -15,6 +15,7 @@ from typing_extensions import override, TypeAlias
|
|||
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
from torch._inductor import config
|
||||
from torch.monitor import _WaitCounter
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -162,29 +163,31 @@ class RemoteCache(Generic[_T]):
|
|||
# See if the cache contains `key`. Returns `None` if the value is not
|
||||
# present in the cache.
|
||||
def get(self, key: str) -> Optional[_T]:
|
||||
sample = self._create_sample()
|
||||
try:
|
||||
result = self._get(key, sample)
|
||||
cache_stats.get(type(self).__name__, result)
|
||||
except Exception:
|
||||
cache_stats.exception(type(self).__name__)
|
||||
raise
|
||||
self._log_sample(sample)
|
||||
return result
|
||||
with _WaitCounter("pytorch.remote_cache.get").guard():
|
||||
sample = self._create_sample()
|
||||
try:
|
||||
result = self._get(key, sample)
|
||||
cache_stats.get(type(self).__name__, result)
|
||||
except Exception:
|
||||
cache_stats.exception(type(self).__name__)
|
||||
raise
|
||||
self._log_sample(sample)
|
||||
return result
|
||||
|
||||
# Add `value` to the cache with the key `key`. Note that `None` is not a
|
||||
# valid value even if _T supports it (because you can't tell the difference
|
||||
# between `None` and a missing cache entry).
|
||||
def put(self, key: str, value: _T) -> None:
|
||||
assert value is not None
|
||||
sample = self._create_sample()
|
||||
try:
|
||||
self._put(key, value, sample)
|
||||
cache_stats.put(type(self).__name__)
|
||||
except Exception:
|
||||
cache_stats.exception(type(self).__name__)
|
||||
raise
|
||||
self._log_sample(sample)
|
||||
with _WaitCounter("pytorch.remote_cache.put").guard():
|
||||
assert value is not None
|
||||
sample = self._create_sample()
|
||||
try:
|
||||
self._put(key, value, sample)
|
||||
cache_stats.put(type(self).__name__)
|
||||
except Exception:
|
||||
cache_stats.exception(type(self).__name__)
|
||||
raise
|
||||
self._log_sample(sample)
|
||||
|
||||
# Used to convert data from the cache into structured data.
|
||||
def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override]
|
||||
|
|
|
|||
Loading…
Reference in a new issue