From 64d44a39a10a3e4b7abc85ffadd831ace3227460 Mon Sep 17 00:00:00 2001 From: "Colin L. Rice" Date: Mon, 2 Dec 2024 12:42:08 -0700 Subject: [PATCH] remote_cache: Add a waitcounter for gets and sets (#141307) This adds a basic waitcounter to help show if we're spending a lot of time doing gets and sets to remote caches Pull Request resolved: https://github.com/pytorch/pytorch/pull/141307 Approved by: https://github.com/masnesral --- torch/_inductor/remote_cache.py | 39 ++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 4e53b920a42..7dcfd5efbb4 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -15,6 +15,7 @@ from typing_extensions import override, TypeAlias from torch._dynamo.utils import dynamo_timed from torch._inductor import config +from torch.monitor import _WaitCounter try: @@ -162,29 +163,31 @@ class RemoteCache(Generic[_T]): # See if the cache contains `key`. Returns `None` if the value is not # present in the cache. def get(self, key: str) -> Optional[_T]: - sample = self._create_sample() - try: - result = self._get(key, sample) - cache_stats.get(type(self).__name__, result) - except Exception: - cache_stats.exception(type(self).__name__) - raise - self._log_sample(sample) - return result + with _WaitCounter("pytorch.remote_cache.get").guard(): + sample = self._create_sample() + try: + result = self._get(key, sample) + cache_stats.get(type(self).__name__, result) + except Exception: + cache_stats.exception(type(self).__name__) + raise + self._log_sample(sample) + return result # Add `value` to the cache with the key `key`. Note that `None` is not a # valid value even if _T supports it (because you can't tell the difference # between `None` and a missing cache entry). def put(self, key: str, value: _T) -> None: - assert value is not None - sample = self._create_sample() - try: - self._put(key, value, sample) - cache_stats.put(type(self).__name__) - except Exception: - cache_stats.exception(type(self).__name__) - raise - self._log_sample(sample) + with _WaitCounter("pytorch.remote_cache.put").guard(): + assert value is not None + sample = self._create_sample() + try: + self._put(key, value, sample) + cache_stats.put(type(self).__name__) + except Exception: + cache_stats.exception(type(self).__name__) + raise + self._log_sample(sample) # Used to convert data from the cache into structured data. def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override]