[inductor] Add API to make post_grad_custom passes cache-able (#137298)
Summary: See https://github.com/pytorch/pytorch/issues/130772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137298
Approved by: https://github.com/oulgen, https://github.com/eellison
This commit is contained in:
  parent 8aa110cb00
  commit 319eda9dfd

5 changed files with 167 additions and 18 deletions
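In short: a post-grad custom pass becomes cache-able by implementing the new CustomGraphPass interface instead of registering a bare callable. A minimal sketch of the intended end-to-end usage, based on the API added below (the pass itself is an illustrative no-op):

import torch
import torch._inductor.config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        # rewrite `graph` in place; a no-op keeps the sketch runnable
        pass

    def uuid(self) -> bytes:
        # hash this source file so any edit to the pass invalidates cached entries
        return get_hash_for_files((__file__,))


inductor_config.post_grad_custom_pre_pass = MyPass()

compiled = torch.compile(lambda x: x + 1)
compiled(torch.randn(4))  # compiled artifacts remain eligible for the FX graph cache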
@@ -2,8 +2,9 @@
 import functools
 import os
 import pickle
 import tempfile
 import unittest
-from typing import List
+from typing import List, Optional, Union
 from unittest import mock

 import torch
@@ -20,6 +21,7 @@ from torch._inductor.codecache import (
     TensorMetadata,
     TensorMetadataAndValues,
 )
+from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
 from torch._inductor.graph import GraphLowering
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.test_case import run_tests, TestCase
@@ -804,6 +806,61 @@ class TestFxGraphCacheHashing(TestCase):
             FxGraphCachePickler.dumps(details3),
         )

+    def test_hash_custom_passes(self):
+        """
+        Test CustomGraphPass usage.
+        """
+
+        class TestCustomGraphPass(CustomGraphPass):
+            def __init__(self):
+                self._uuid = None
+
+            def __call__(self, graph: torch.fx.graph.Graph) -> None:
+                return None
+
+            def uuid(self) -> Optional[Union[bytes, str]]:
+                return self._uuid
+
+        custom_pass = TestCustomGraphPass()
+        with config.patch({"post_grad_custom_pre_pass": custom_pass}):
+            custom_pass._uuid = "1"
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+            custom_pass._uuid = "2"
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+            self.assertEqual(
+                FxGraphCachePickler.dumps(details1),
+                FxGraphCachePickler.dumps(details2),
+            )
+            self.assertNotEqual(
+                FxGraphCachePickler.dumps(details1),
+                FxGraphCachePickler.dumps(details3),
+            )
+
+    def test_get_hash_for_files(self):
+        """
+        Test the get_hash_for_files helper.
+        """
+        with tempfile.NamedTemporaryFile(delete=True) as temp:
+            temp.write(b"contents")
+            temp.flush()
+
+            hash1 = get_hash_for_files((temp.name,))
+            get_hash_for_files.cache_clear()
+            hash2 = get_hash_for_files((temp.name,))
+
+            temp.write(b" ")
+            temp.flush()
+            get_hash_for_files.cache_clear()
+            hash3 = get_hash_for_files((temp.name,))
+
+            self.assertEqual(hash1, hash2)
+            self.assertNotEqual(hash1, hash3)
+
+
 class TestCudaCompileCommand(TestCase):
     @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
     @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
     def test_cuda_compile_command(self):
@@ -8,6 +8,7 @@ import torch._inductor.pattern_matcher as pattern_matcher
 import torch.fx as fx
 from torch._dynamo.utils import counters
 from torch._inductor import config
+from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
 from torch._inductor.lowering import lowerings as L
 from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
 from torch._inductor.test_case import run_tests, TestCase
@@ -107,13 +108,16 @@ class TestPostGradCustomPrePostPass(TestCustomPassBase):
         _register_fusion_lowering(_mkldnn_conv_relu_pattern(), custom_pass_dict)

     # custom post grad pass
-    class _CustomPass(PatternMatcherPass):
+    class _CustomPass(PatternMatcherPass, CustomGraphPass):
         def __init__(self) -> None:
             super().__init__()

         def __call__(self, g: torch.fx.graph.Graph):
             self.apply(g)

+        def uuid(self) -> bytes:
+            return get_hash_for_files((__file__,))
+
     # case model
     class _ConvReLU(torch.nn.Module):
         def __init__(self, ic, oc):
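Worth noting from this test change: because CustomGraphPass is an ABC, an existing pass class such as a PatternMatcherPass can be made cache-friendly simply by adding it as a base and supplying uuid(). A hedged sketch of that pattern outside the test harness (the class name is illustrative):

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
from torch._inductor.pattern_matcher import PatternMatcherPass


class CacheablePatternPass(PatternMatcherPass, CustomGraphPass):
    def __call__(self, g: torch.fx.graph.Graph) -> None:
        # PatternMatcherPass.apply runs whatever patterns were registered on self
        self.apply(g)

    def uuid(self) -> bytes:
        # the pattern registrations live in this file, so hash this file
        return get_hash_for_files((__file__,))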
@@ -65,6 +65,7 @@ from torch._inductor.codegen.rocm.compile_command import (
     rocm_compile_command,
     rocm_compiler,
 )
+from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
 from torch._utils_internal import log_cache_bypass

 from .remote_cache import create_cache
@@ -605,8 +606,7 @@ class FxGraphCachePickler(pickle.Pickler):
         try:
             pickler.dump(obj)
         except (TypeError, AttributeError) as e:
-            # Some configs options are callables, e.g., post_grad_custom_pre_pass,
-            # and may not pickle.
+            # Some configs options may not pickle.
             log.warning("Can't pickle", exc_info=True)
             raise BypassFxGraphCache("Config options may be unpickleable") from e
         return stream.getvalue()
@@ -778,6 +778,22 @@ class FxGraphHashDetails:
         self.system_info = CacheBase.get_system()
         self.inductor_config = config.save_config_portable()

+        # Custom post grad passes should provide an ID to hash.
+        self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
+            config.post_grad_custom_pre_pass
+        )
+        self.post_grad_custom_post_pass = self._get_custom_pass_detail(
+            config.post_grad_custom_post_pass
+        )
+
+    def _get_custom_pass_detail(
+        self, custom_pass: CustomGraphPassType
+    ) -> Optional[Any]:
+        if not custom_pass:
+            return None
+        assert isinstance(custom_pass, CustomGraphPass)
+        return custom_pass.uuid()
+
     def debug_lines(self) -> List[str]:
         """
         Get a printable string describing in more detail all the attributes
@@ -1257,6 +1273,12 @@ class FxGraphCache:
         Check some conditions that would preclude caching and raise BypassFxGraphCache
         to bypass in case caching is not possible.
         """
+        # Post grad custom passes must implement the CustomGraphPass interface,
+        # or we don't know how to include them in the cache key calculation.
+        for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
+            if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
+                raise BypassFxGraphCache("Unsupported post grad custom pass")
+
         # Freezing can embed constants that wouldn't be static across runs.
        if config.freezing or config.aot_inductor.use_runtime_constant_folding:
             raise BypassFxGraphCache(
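The practical consequence of this check: a plain callable is still accepted as a post-grad pass, but it now silently disables the FX graph cache rather than risking stale cache hits. A short sketch of the three cases, with behavior inferred from the check above:

import torch._inductor.config as inductor_config

# 1) Plain callable: still runs, but the check above raises BypassFxGraphCache,
#    so compilation proceeds without the FX graph cache.
inductor_config.post_grad_custom_pre_pass = lambda g: None

# 2) A CustomGraphPass whose uuid() returns None: same bypass.
# 3) A CustomGraphPass with a stable, non-None uuid(): cache-able; the uuid is
#    folded into FxGraphHashDetails and hence into the cache key.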
@@ -3,6 +3,7 @@ import sys
 from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union

 import torch
+import torch._inductor.custom_graph_pass
 from torch._environment import is_fbcode


@@ -131,18 +132,10 @@ b2b_gemm_pass = False

 # register custom graph optimization pass hook. so far, pre/post passes are
 # only applied before/after pattern_matcher in post_grad_passes.
-#
-# def my_custom_pre_pass(graph: torch.fx.graph.Graph):
-#     # my custom graph optimization pass
-#     ...
-#
-# def my_custom_post_pass(graph: torch.fx.graph.Graph):
-#     # my custom graph optimization pass
-#     ...
-#
-# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
-# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
-post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
-post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
+# Implement CustomGraphPass to allow Inductor to cache compiled artifacts
+# to which your custom passes have been applied:
+post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
+post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None

 # Registers a custom joint graph pass.
 joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
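Note the type change above: CustomGraphPassType is a union that still admits a bare callable, so existing registrations keep type-checking; only cache behavior distinguishes the two styles. A brief sketch of both assignments (pass names are illustrative):

import torch
import torch._inductor.config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


def legacy_post_pass(graph: torch.fx.graph.Graph) -> None:
    ...  # accepted as before, but defeats caching (see the bypass check above)


class CacheablePrePass(CustomGraphPass):
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        ...

    def uuid(self) -> bytes:
        return get_hash_for_files((__file__,))


inductor_config.post_grad_custom_pre_pass = CacheablePrePass()
inductor_config.post_grad_custom_post_pass = legacy_post_pass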
@@ -1252,8 +1245,6 @@ class trace:

 _save_config_ignore = [
     # workaround: "Can't pickle <function ...>"
     "trace.upload_tar",
-    "post_grad_custom_post_pass",
-    "post_grad_custom_pre_pass",
     "joint_custom_pre_pass",
     "joint_custom_post_pass",
     "pre_grad_custom_pass",
@@ -1267,6 +1258,9 @@ _cache_config_ignore_prefix = [
     # not relevant
     "worker_start_method",
     "compile_threads",
+    # see CustomGraphPass; these are handled specially
+    "post_grad_custom_post_pass",
+    "post_grad_custom_pre_pass",
 ]

 # External callable for matmul tuning candidates
torch/_inductor/custom_graph_pass.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import hashlib
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Callable, Optional, Tuple, Union
+from typing_extensions import TypeAlias
+
+import torch.fx.graph
+
+
+class CustomGraphPass(ABC):
+    """
+    Implement this interface for custom Graph passes:
+
+    1) The __call__() method contains the implementation of the custom pass.
+
+    2) The uuid() method enables inductor to cache compiled graphs when your custom
+    passes are applied. This method can return any identifier as long as it uniquely
+    identifies your implementation (and can be pickled). The caching logic includes this
+    identifier in its key calculation, i.e., any new value will effectively invalidate
+    existing entries. We expect custom passes will typically depend purely on the
+    textual representation of the implementation. In that case, we recommend using the
+    'get_hash_for_files' helper below to compute a unique hash from the contents of a
+    static list of source files, i.e., the source(s) containing the custom pass
+    implementation. That approach ensures that any change to the implementation will
+    mean a new uuid.
+
+    ** IMPORTANT ** If your custom pass's behavior depends on some external state, then
+    you'll need to implement something more complicated (or disable caching).
+
+    EXAMPLE:
+
+    class MyCustomGraphPass(CustomGraphPass):
+        def __call__(self, graph: torch.fx.graph.Graph) -> None:
+            # my custom graph optimization pass
+            # ...
+
+        def uuid(self) -> Optional[Any]:
+            return get_hash_for_files((__file__,))
+
+    """
+
+    @abstractmethod
+    def __call__(self, graph: torch.fx.graph.Graph) -> None:
+        """
+        Implementation of the custom pass.
+        """
+
+    @abstractmethod
+    def uuid(self) -> Optional[Any]:
+        """
+        Return an ID to uniquely identify your custom pass implementation. Return None
+        to skip inductor code caching entirely.
+        """
+
+
+CustomGraphPassType: TypeAlias = Optional[
+    Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
+]
+
+
+@lru_cache(1)
+def get_hash_for_files(paths: Tuple[str, ...], extra: str = "") -> bytes:
+    """
+    Helper to compute a unique digest by hashing the contents of a list of files.
+    """
+    hasher = hashlib.sha256()
+    hasher.update(extra.encode("utf-8"))
+    for path in paths:
+        with open(path, "rb") as f:
+            hasher.update(path.encode("utf-8"))
+            hasher.update(f.read())
+    return hasher.digest()
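One way to handle the docstring's external-state caveat without disabling caching: fold the state into the uuid via the helper's extra argument. A hedged sketch (the pass and its flag are hypothetical):

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class FlagAwarePass(CustomGraphPass):
    def __init__(self, aggressive: bool):
        self.aggressive = aggressive  # external state that changes pass behavior

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        ...  # rewrite `graph`, more aggressively when self.aggressive is set

    def uuid(self) -> bytes:
        # hash both the source contents and the flag, so toggling the flag
        # yields a different cache key
        return get_hash_for_files((__file__,), extra=str(self.aggressive))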