[inductor] Add API to make post_grad_custom passes cache-able (#137298)

Summary: See https://github.com/pytorch/pytorch/issues/130772
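A minimal usage sketch of the new API (assuming a build that includes this change; the pass body and the name `MyCustomGraphPass` are illustrative, while `CustomGraphPass`, `get_hash_for_files`, and the config fields come from this diff):

```python
import torch
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyCustomGraphPass(CustomGraphPass):  # illustrative name
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        # Mutate the post-grad FX graph in place with the custom optimization.
        ...

    def uuid(self) -> bytes:
        # Hash the source file(s) implementing this pass; any edit to the
        # implementation yields a new uuid and invalidates cached entries.
        return get_hash_for_files((__file__,))


# Register the pass. Because it implements CustomGraphPass, FxGraphCache can
# fold uuid() into its cache key instead of bypassing caching altogether.
config.post_grad_custom_pre_pass = MyCustomGraphPass()
```

Plain callables assigned to `post_grad_custom_pre_pass`/`post_grad_custom_post_pass` still type-check against `CustomGraphPassType`, but the FX graph cache is bypassed for them.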

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137298
Approved by: https://github.com/oulgen, https://github.com/eellison
Sam Larsen 2024-10-07 10:41:50 -07:00 committed by PyTorch MergeBot
parent 8aa110cb00
commit 319eda9dfd
5 changed files with 167 additions and 18 deletions

View file

@@ -2,8 +2,9 @@
import functools
import os
import pickle
import tempfile
import unittest
from typing import List
from typing import List, Optional, Union
from unittest import mock
import torch
@@ -20,6 +21,7 @@ from torch._inductor.codecache import (
TensorMetadata,
TensorMetadataAndValues,
)
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
from torch._inductor.graph import GraphLowering
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
@@ -804,6 +806,61 @@ class TestFxGraphCacheHashing(TestCase):
FxGraphCachePickler.dumps(details3),
)
def test_hash_custom_passes(self):
"""
Test CustomGraphPass usage.
"""
class TestCustomGraphPass(CustomGraphPass):
def __init__(self):
self._uuid = None
def __call__(self, graph: torch.fx.graph.Graph) -> None:
return None
def uuid(self) -> Optional[Union[bytes, str]]:
return self._uuid
custom_pass = TestCustomGraphPass()
with config.patch({"post_grad_custom_pre_pass": custom_pass}):
custom_pass._uuid = "1"
details1 = FxGraphHashDetails(None, [], {}, [])
details2 = FxGraphHashDetails(None, [], {}, [])
custom_pass._uuid = "2"
details3 = FxGraphHashDetails(None, [], {}, [])
self.assertEqual(
FxGraphCachePickler.dumps(details1),
FxGraphCachePickler.dumps(details2),
)
self.assertNotEqual(
FxGraphCachePickler.dumps(details1),
FxGraphCachePickler.dumps(details3),
)
def test_get_hash_for_files(self):
"""
Test the get_hash_for_files helper.
"""
with tempfile.NamedTemporaryFile(delete=True) as temp:
temp.write(b"contents")
temp.flush()
hash1 = get_hash_for_files((temp.name,))
get_hash_for_files.cache_clear()
hash2 = get_hash_for_files((temp.name,))
temp.write(b" ")
temp.flush()
get_hash_for_files.cache_clear()
hash3 = get_hash_for_files((temp.name,))
self.assertEqual(hash1, hash2)
self.assertNotEqual(hash1, hash3)
class TestCudaCompileCommand(TestCase):
@unittest.skipIf(not HAS_CUDA, "Requires CUDA")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_cuda_compile_command(self):

View file

@@ -8,6 +8,7 @@ import torch._inductor.pattern_matcher as pattern_matcher
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
from torch._inductor.lowering import lowerings as L
from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
@@ -107,13 +108,16 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
_register_fusion_lowering(_mkldnn_conv_relu_pattern(), custom_pass_dict)
# custom post grad pass
class _CustomPass(PatternMatcherPass):
class _CustomPass(PatternMatcherPass, CustomGraphPass):
def __init__(self) -> None:
super().__init__()
def __call__(self, g: torch.fx.graph.Graph):
self.apply(g)
def uuid(self) -> bytes:
return get_hash_for_files((__file__,))
# case model
class _ConvReLU(torch.nn.Module):
def __init__(self, ic, oc):

View file

@@ -65,6 +65,7 @@ from torch._inductor.codegen.rocm.compile_command import (
rocm_compile_command,
rocm_compiler,
)
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
from torch._utils_internal import log_cache_bypass
from .remote_cache import create_cache
@@ -605,8 +606,7 @@ class FxGraphCachePickler(pickle.Pickler):
try:
pickler.dump(obj)
except (TypeError, AttributeError) as e:
# Some configs options are callables, e.g., post_grad_custom_pre_pass,
# and may not pickle.
# Some config options may not pickle.
log.warning("Can't pickle", exc_info=True)
raise BypassFxGraphCache("Config options may be unpickleable") from e
return stream.getvalue()
@@ -778,6 +778,22 @@ class FxGraphHashDetails:
self.system_info = CacheBase.get_system()
self.inductor_config = config.save_config_portable()
# Custom post grad passes should provide an ID to hash.
self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
config.post_grad_custom_pre_pass
)
self.post_grad_custom_post_pass = self._get_custom_pass_detail(
config.post_grad_custom_post_pass
)
def _get_custom_pass_detail(
self, custom_pass: CustomGraphPassType
) -> Optional[Any]:
if not custom_pass:
return None
assert isinstance(custom_pass, CustomGraphPass)
return custom_pass.uuid()
def debug_lines(self) -> List[str]:
"""
Get a printable string describing in more detail all the attributes
@@ -1257,6 +1273,12 @@ class FxGraphCache:
Check some conditions that would preclude caching and raise BypassFxGraphCache
to bypass in case caching is not possible.
"""
# Post grad custom passes must implement the CustomGraphPass interface, or we
# don't know how to include them in the cache key calculation.
for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
raise BypassFxGraphCache("Unsupported post grad custom pass")
# Freezing can embed constants that wouldn't be static across runs.
if config.freezing or config.aot_inductor.use_runtime_constant_folding:
raise BypassFxGraphCache(

View file

@@ -3,6 +3,7 @@ import sys
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.custom_graph_pass
from torch._environment import is_fbcode
@@ -131,18 +132,10 @@ b2b_gemm_pass = False
# Register a custom graph optimization pass hook. So far, pre/post passes are
# only applied before/after pattern_matcher in post_grad_passes.
#
# def my_custom_pre_pass(graph: torch.fx.graph.Graph):
# # my custom graph optimization pass
# ...
#
# def my_custom_post_pass(graph: torch.fx.graph.Graph):
# # my custom graph optimization pass
# ...
#
# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
# Implement CustomGraphPass to allow Inductor to cache compiled artifacts
# to which your custom passes have been applied:
post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
# Registers a custom joint graph pass.
joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
@@ -1252,8 +1245,6 @@ class trace:
_save_config_ignore = [
# workaround: "Can't pickle <function ...>"
"trace.upload_tar",
"post_grad_custom_post_pass",
"post_grad_custom_pre_pass",
"joint_custom_pre_pass",
"joint_custom_post_pass",
"pre_grad_custom_pass",
@@ -1267,6 +1258,9 @@ _cache_config_ignore_prefix = [
# not relevant
"worker_start_method",
"compile_threads",
# see CustomGraphPass; these are handled specially
"post_grad_custom_post_pass",
"post_grad_custom_pre_pass",
]
# External callable for matmul tuning candidates

View file

@@ -0,0 +1,72 @@
import hashlib
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Callable, Optional, Tuple, Union
from typing_extensions import TypeAlias
import torch.fx.graph
class CustomGraphPass(ABC):
"""
Implement this interface for custom Graph passes:
1) The __call__() method contains the implementation of the custom pass.
2) The uuid() method enables inductor to cache compiled graphs when your custom
passes are applied. This method can return any identifier as long as it uniquely
identifies your implementation (and can be pickled). The caching logic includes this
identifier in its key calculation, i.e., any new value will effectively invalidate
existing entries. We expect that custom passes will typically depend purely on the
textual representation of the implementation. In that case, we recommend using the
'get_hash_for_files' helper below to compute a unique hash from the contents of a
static list of source files, i.e., the source(s) containing the custom pass
implementation. That approach ensures that any change to the implementation will
mean a new uuid.
** IMPORTANT ** If your custom pass's behavior depends on some external state, then
you'll need to implement something more complicated (or disable caching).
EXAMPLE:
class MyCustomGraphPass(CustomGraphPass):
def __call__(self, graph: torch.fx.graph.Graph) -> None:
# my custom graph optimization pass
# ...
def uuid(self) -> Optional[Any]:
return get_hash_for_files((__file__,))
"""
@abstractmethod
def __call__(self, graph: torch.fx.graph.Graph) -> None:
"""
Implementation of the custom pass.
"""
@abstractmethod
def uuid(self) -> Optional[Any]:
"""
Return an ID to uniquely identify your custom pass implementation. Return None
to skip inductor code caching entirely.
"""
CustomGraphPassType: TypeAlias = Optional[
Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
]
@lru_cache(1)
def get_hash_for_files(paths: Tuple[str, ...], extra: str = "") -> bytes:
"""
Helper to compute a unique hash from the contents of a list of files.
"""
hasher = hashlib.sha256()
hasher.update(extra.encode("utf-8"))
for path in paths:
with open(path, "rb") as f:
hasher.update(path.encode("utf-8"))
hasher.update(f.read())
return hasher.digest()
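A small usage sketch of the helper (illustrative, assuming a build containing this file and a context where `__file__` is defined):

```python
from torch._inductor.custom_graph_pass import get_hash_for_files

# Repeated calls with the same arguments hit the lru_cache and return the
# same digest without re-reading the file.
h1 = get_hash_for_files((__file__,))
assert h1 == get_hash_for_files((__file__,))

# `extra` is folded into the hash, so it can version a pass independently of
# the file contents.
assert get_hash_for_files((__file__,), extra="v2") != h1
```

Because of the `lru_cache(1)` memoization, edits to a file within a running process are only observed after `get_hash_for_files.cache_clear()`, which is what `test_get_hash_for_files` above does.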