From 3f5d8636aaa34c7a78213bf32cc30a946ff57b46 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 29 May 2024 18:24:22 -0700 Subject: [PATCH] [inductor] Copy RedisRemoteCacheBackend into pytorch (#127480) Summary: We need an implementation of RedisRemoteCacheBackend with the same API that we're using for FbMemcacheRemoteFxGraphCacheBackend. So we'll stop using the Triton implementation and adapt a version for use by inductor. I also renamed parameters and cache entries to match our cache terminology. Test Plan: Ran this command twice and inspected log output to ensure I got cache hits: ``` TORCH_LOGS=+torch._inductor.codecache TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=1 python benchmarks/dynamo/torchbench.py --performance --inductor --device cuda --training --amp --print-compilation-time --only dcgan ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127480 Approved by: https://github.com/oulgen --- mypy.ini | 7 +++ test/inductor/test_codecache.py | 2 +- test/inductor/test_max_autotune.py | 2 +- torch/_inductor/codecache.py | 14 +++--- torch/_inductor/remote_cache.py | 46 ++++++++++++++++++++ torch/_inductor/runtime/triton_heuristics.py | 4 +- 6 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 torch/_inductor/remote_cache.py diff --git a/mypy.ini b/mypy.ini index 48bd363ef6d..7d51847da44 100644 --- a/mypy.ini +++ b/mypy.ini @@ -294,3 +294,10 @@ ignore_missing_imports = True [mypy-torch_xla.*] ignore_missing_imports = True + +# +# Third party dependencies that are optional. +# + +[mypy-redis] +ignore_missing_imports = True \ No newline at end of file diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index af12454df3c..a96d9aa67e8 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -197,7 +197,7 @@ class TestFxGraphCache(TestCase): cache_module = ( "triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" if config.is_fbcode() - else "triton.runtime.cache.RedisRemoteCacheBackend" + else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) with config.patch( diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index cef7d610ee4..eed927a9864 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -269,7 +269,7 @@ class TestMaxAutotune(TestCase): cache_module = ( "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" if config.is_fbcode() - else "triton.runtime.cache.RedisRemoteCacheBackend" + else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) with config.patch( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 62c7252db49..c47f0175148 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -976,16 +976,16 @@ class FxGraphCache: if remote: cache_id = "fx-graph-v1" try: - import triton - if config.is_fbcode(): - remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend( - cache_id + from triton.runtime.fb_memcache import ( + FbMemcacheRemoteFxGraphCacheBackend, ) + + remote_cache = FbMemcacheRemoteFxGraphCacheBackend(cache_id) else: - remote_cache = triton.runtime.cache.RedisRemoteCacheBackend( - cache_id - ) + from torch._inductor.remote_cache import RedisRemoteCacheBackend + + remote_cache = RedisRemoteCacheBackend(cache_id) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py new file mode 100644 index 00000000000..7c40f603c4d --- /dev/null +++ b/torch/_inductor/remote_cache.py @@ -0,0 +1,46 @@ +import os +from abc import abstractmethod + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, cache_id: str): + pass + + @abstractmethod + def get(self, key: str): + pass + + @abstractmethod + def put(self, key: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + """ + A Redis implementation of a remote/distributed cache. + """ + + def __init__(self, cache_id: str): + import redis + + self._cache_id = cache_id + self._key_fmt = os.environ.get( + "TORCHINDUCTOR_REDIS_KEY_FORMAT", "pt2:{cache_id}:{key}" + ) + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, key: str) -> str: + return self._key_fmt.format(cache_id=self._cache_id, key=key) + + def get(self, key: str): + return self._redis.get(self._get_key(key)) + + def put(self, key: str, data: bytes): + return self._redis.set(self._get_key(key), data) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5a27f7a08cd..75584a60c0f 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1078,7 +1078,9 @@ def cached_autotune( key ) else: - remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key) + from torch._inductor.remote_cache import RedisRemoteCacheBackend + + remote_cache = RedisRemoteCacheBackend(key) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True)