2024-11-27 14:04:14 +00:00
|
|
|
"""
|
|
|
|
|
This provides an abstract class which parametrizes over an "output code" concept
|
|
|
|
|
for Inductor. Intuitively, this represents the compiled callable which Inductor
|
|
|
|
|
produces which you can call to get optimized code. However, this callable
|
|
|
|
|
has some other capabilities:
|
|
|
|
|
|
|
|
|
|
- It is serializable, so you can save/load this product from disk without
|
|
|
|
|
having to do compilation again.
|
|
|
|
|
|
|
|
|
|
- (When using remote cache) it is addressable, so you can save just a key
|
|
|
|
|
which you can use to load this product from remote cache later.
|
|
|
|
|
|
|
|
|
|
This class is abstract because we have several different implementations of
|
|
|
|
|
serialized format:
|
|
|
|
|
|
|
|
|
|
- Python wrapper (the default)
|
|
|
|
|
|
|
|
|
|
- AOTInductor (this produces ABI stable binaries which work across PyTorch
|
|
|
|
|
versions)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import dataclasses
|
2024-12-03 00:26:07 +00:00
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
from pathlib import Path
|
2025-01-22 02:51:27 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
2024-11-27 14:04:14 +00:00
|
|
|
from typing_extensions import TypeAlias
|
|
|
|
|
|
2024-11-28 04:43:07 +00:00
|
|
|
import torch
|
2024-12-13 17:37:20 +00:00
|
|
|
from torch._dynamo.utils import counters, get_runtime_metrics_context
|
2024-11-28 04:43:07 +00:00
|
|
|
from torch._inductor.cudagraph_utils import (
|
|
|
|
|
BoxedDeviceIndex,
|
|
|
|
|
CudagraphCachedInfo,
|
|
|
|
|
get_placeholder_info,
|
|
|
|
|
log_cudagraph_skip_and_bump_counter,
|
|
|
|
|
)
|
2025-01-30 03:10:00 +00:00
|
|
|
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
|
2024-12-04 02:06:09 +00:00
|
|
|
from torch._inductor.utils import (
|
|
|
|
|
align_inputs_from_check_idxs,
|
|
|
|
|
BoxedBool,
|
|
|
|
|
InputType,
|
|
|
|
|
output_node,
|
|
|
|
|
set_tracing_context_output_strides,
|
|
|
|
|
)
|
2024-12-13 12:37:22 +00:00
|
|
|
from torch.utils._ordered_set import OrderedSet
|
2024-11-28 04:43:07 +00:00
|
|
|
|
|
|
|
|
from . import config
|
2024-11-27 14:04:14 +00:00
|
|
|
from .runtime.autotune_cache import AutotuneCacheBundler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
2025-01-20 20:27:30 +00:00
|
|
|
from collections import Counter
|
|
|
|
|
from collections.abc import Sequence
|
|
|
|
|
|
2024-11-27 14:04:14 +00:00
|
|
|
from torch._inductor import metrics
|
|
|
|
|
from torch._inductor.graph import GraphLowering
|
|
|
|
|
|
|
|
|
|
from .compile_fx import _CompileFxKwargs
|
|
|
|
|
from .triton_bundler import TritonKernelArtifacts
|
|
|
|
|
|
2024-12-03 00:26:07 +00:00
|
|
|
# Module-level logger for this file.
log = logging.getLogger(__name__)
|
|
|
|
|
|
2024-11-27 14:04:14 +00:00
|
|
|
|
2024-12-02 15:46:45 +00:00
|
|
|
@dataclasses.dataclass
class OutputCode:
    """Abstract interface for a compiled Inductor artifact.

    Subclasses represent the callable produced by compilation; this base
    class carries only cache bookkeeping fields and declares the runtime
    hooks every implementation must provide (see the module docstring for
    the serialization/addressability contract).
    """

    # TODO: Remove underscores here

    # None if the output is not remote cacheable
    _fx_graph_cache_key: Optional[str] = dataclasses.field(default=None, init=False)

    # How long it took to compile this OutputCode, end to end
    _time_taken_ns: Optional[int] = dataclasses.field(default=None, init=False)

    def __call__(self, inputs: Sequence[Any]) -> Any:
        # Run the compiled artifact on the given inputs; must be overridden.
        raise NotImplementedError(type(self))

    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        constants: CompiledFxGraphConstants,
    ) -> None:
        # Hook run after compilation (or a cache hit) to finish runtime
        # setup; must be overridden.
        raise NotImplementedError(type(self))

    # TODO: Get rid of this
    def set_triton_bundle(self, triton_bundle: Any) -> None:
        # Attach precompiled Triton kernel artifacts; must be overridden.
        raise NotImplementedError(type(self))
|
2024-11-28 14:18:38 +00:00
|
|
|
|
2024-11-27 14:04:14 +00:00
|
|
|
|
|
|
|
|
# Alias for a stride expression serialized as a string (used in the
# ``output_strides`` field of CompiledFxGraph).
_StrideExprStr: TypeAlias = str
|
|
|
|
|
|
|
|
|
|
|
2024-11-28 04:43:07 +00:00
|
|
|
# copy_ fails when trying to write to tensors with memory overlap,
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
def get_expanded_dims(t: torch.Tensor) -> Optional[list[int]]:
    """Return indices of expanded (stride-0, size > 1) dimensions of *t*.

    Returns None when *t* is not a Tensor, so callers can pass arbitrary
    graph inputs through this helper.
    """
    # Fix: the annotation previously claimed ``list[int]`` although this
    # branch returns None for non-Tensor inputs.
    if not isinstance(t, torch.Tensor):
        return None
    # A dimension produced by expand() has stride 0 but size > 1.
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
|
|
|
|
|
|
|
|
|
|
|
2025-01-20 20:27:30 +00:00
|
|
|
def index_expanded_dims(t: torch.Tensor, expanded_dims: list[int]) -> torch.Tensor:
    """Narrow *t* to a single element along each dim in *expanded_dims*.

    Slicing an expanded (stride-0) dimension down to one element produces a
    view without internal overlap along that dimension.
    """
    result = t
    for dim in expanded_dims:
        result = torch.ops.aten.slice(result, dim, 0, 1)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def complex_memory_overlap(t: torch.Tensor) -> bool:
    """Return True if *t* genuinely has internal memory overlap.

    torch._debug_has_internal_overlap may only report a "maybe"; when it
    does, walk the dimensions in increasing stride order and confirm that
    each dimension's stride is at least the span covered by the previously
    placed dimension — anything smaller implies real overlap.
    """
    if config.always_complex_memory_overlap_TESTING_ONLY:
        return True

    # Slice away expanded dims and squeeze size-1 dims first so they do not
    # cause a false positive in the overlap check below.
    t = index_expanded_dims(t, get_expanded_dims(t)).squeeze()
    if torch._debug_has_internal_overlap(t) == 0:
        return False

    strides = t.stride()
    sizes = t.shape
    # Visit dimensions from smallest stride to largest (stable on ties,
    # matching a sort over (stride, dim-index) pairs).
    order = sorted(range(len(strides)), key=lambda d: strides[d])
    prev_stride = 1
    prev_size = 1
    for dim in order:
        if strides[dim] < prev_stride * prev_size:
            return True
        prev_stride = strides[dim]
        prev_size = sizes[dim]
    return False
|
|
|
|
|
|
|
|
|
|
|
2024-12-04 02:06:09 +00:00
|
|
|
def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    constants: dict[str, torch.Tensor],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
    runs it on compiled_graph.

    If the cached info recorded no failure reasons, replaces
    `compiled_graph.current_callable` with a cudagraphified wrapper.
    Otherwise disables the `cudagraphs` box in place, logs the skip reason
    when a CUDA device is involved, and — for backward graphs under
    cudagraph trees — still notifies the cudagraph manager that the
    backward is running.

    Mutates the `compiled_graph.current_callable` and `cudagraphs`
    """
    assert compiled_graph.current_callable is not None
    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]

    if not cudagraph_fail_reasons:
        fx_kwargs = compiled_graph.fx_kwargs
        static_input_idxs = fx_kwargs["static_input_idxs"]

        placeholders = cached_info.placeholders
        stack_traces = cached_info.stack_traces
        if not config.triton.cudagraph_trees:
            # Force specialize all inputs so that CUDA graphs will work
            for t in example_inputs:
                if isinstance(t, torch.SymInt):
                    int(t)  # guard

        if (
            boxed_forward_device_index is not None
            and not is_inference
            and not is_backward
        ):
            # Record the forward's device so the backward can look up the
            # same cudagraph manager later.
            boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

        from .compile_fx import cudagraphify

        current_callable = compiled_graph.current_callable
        assert current_callable is not None
        compiled_graph.current_callable = cudagraphify(
            current_callable,
            static_input_idxs=static_input_idxs or (),
            device_index=next(iter(compiled_graph.device_idxs)),
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(constants.values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )

    else:
        BoxedBool.disable(cudagraphs)

        # See [Backward Generation Handling]
        # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
        # know we are we running the backward even if we will not run it in cudagraphs
        if is_backward and config.triton.cudagraph_trees:
            assert boxed_forward_device_index is not None
            assert boxed_forward_device_index.value is not None
            compiled_graph_callable = compiled_graph.current_callable

            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_forward_device_index.value, create_if_none_exists=False
            )
            # should already exist from forward
            assert manager is not None

            def compiled_artifact(new_inputs: list[Any]) -> Any:
                # Tell the manager we are in the backward pass, then run the
                # original (non-cudagraphed) callable.
                manager.set_to_running_backward()  # type: ignore[union-attr]
                return compiled_graph_callable(new_inputs)

            compiled_graph.current_callable = compiled_artifact

        if "cuda" in compiled_graph.device_types:
            # prefer better disable_cudagraphs_reason bc stack trace
            # TODO: migrate all disable reasons to stack trace, refactor
            if compiled_graph.disabled_cudagraphs_reason:
                log_cudagraph_skip_and_bump_counter(
                    compiled_graph.disabled_cudagraphs_reason
                )
            else:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def maybe_realign_inputs(
|
|
|
|
|
ran_cudagraphs: BoxedBool,
|
|
|
|
|
compiled_graph: CompiledFxGraph,
|
|
|
|
|
inputs_to_check: Sequence[int],
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Realigns input strides from inputs_to_check if
|
|
|
|
|
we didn't end up running cudagraphs. Mutates
|
|
|
|
|
`compiled_graph.current_callable` if cudagraphs
|
|
|
|
|
was run. Otherwise, does nothing.
|
|
|
|
|
"""
|
|
|
|
|
if not ran_cudagraphs:
|
|
|
|
|
assert compiled_graph.current_callable is not None
|
|
|
|
|
new_callable = align_inputs_from_check_idxs(
|
|
|
|
|
compiled_graph.current_callable, inputs_to_check
|
|
|
|
|
)
|
|
|
|
|
if new_callable is not compiled_graph.current_callable:
|
|
|
|
|
compiled_graph.current_callable = new_callable
|
|
|
|
|
|
|
|
|
|
|
Refactor optional graph module into CompiledFxGraphConstants (#141897)
FXGraphCache supports freezing, but AOTAutogradCache does not. This is due to the fact that when freezing is turned on, instead of using the constants from the graph module that was saved on cache miss, we have to take the constants from the AOTAutograd generated graph module. This PR does two things:
- It bypasses AOTAutogradCache when freezing is turned on. We should have always been doing this.
- It refactors the code to be way more clear about the constants we're using and when we're using them.
Basically, there are two possible sets of constants we can grab from the compiled fx graph.
1. If freezing is turned off, we save the constants directly in CompiledFxGraph.
2. If freezing is turned on, we save the *names* of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.
We implement two different classes for doing just this: one that has access to the post aotdispatch gm, which supports freezing, and one that doesn't have it, which does not support freezing. Then we construct the wrappers and unwrap the result as needed.
This makes it clear that the gm passed to AOTAutogradCache is *not* part of post compile, only the cache key generated from it is.
The whole flow is pretty confusing, but hopefully this gives us better types and static information for understanding what the different codepaths are doing.
Will add a specific AOTAutogradCache to confirm we bypass freezing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141897
Approved by: https://github.com/ezyang, https://github.com/masnesral
2024-12-04 16:44:29 +00:00
|
|
|
class CompiledFxGraphConstants:
    """Resolver for the constants attached to a CompiledFxGraph.

    This base implementation handles the non-freezing case: the constant
    tensors were saved directly on the CompiledFxGraph, so unwrapping just
    reads them back off the object.

    With freezing, FxGraphCache does not store the constants of the input
    GraphModule it gets from AOTAutograd — it saves only their *names* and
    must fetch the values from the graph module available at runtime. Since
    a graph module is not always available at runtime, that case lives in
    the CompiledFxGraphConstantsWithGm counterpart, which FXGraphCache is
    handed during post compile when freezing is in play.
    """

    def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
        """Return the constant tensors saved directly on *g*."""
        constants = g.constants
        assert constants is not None
        return constants
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
    """
    This version of CompiledFxGraphConstants, instead of grabbing constants
    directly saved on CompiledFxGraphs, will just grab their names. Then, it takes
    a second GraphModule to grab the corresponding constant values out of.

    This is necessary for supporting freezing in FxGraphCache.
    """

    def __init__(self, gm: torch.fx.GraphModule) -> None:
        # Runtime graph module that owns the frozen parameter values.
        self.gm = gm

    # Consistency fix: annotate with builtin ``dict`` to match the base
    # class's unwrap signature (the file uses lazy annotations via
    # ``from __future__ import annotations``).
    def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
        """Combine *g*'s directly-saved constants with frozen parameters
        looked up by name on the runtime graph module.
        """
        # Map each cached constant name to the live value stored under the
        # original attribute name on self.gm.
        frozen_params = {
            name: getattr(self.gm, orig_name)
            for name, orig_name in g.frozen_param_names.items()
        }
        constants = g.constants or {}
        # Frozen params win on a name collision (merged last).
        return {**constants, **frozen_params}
|
Refactor optional graph module into CompiledFxGraphConstants (#141897)
FXGraphCache supports freezing, but AOTAutogradCache does not. This is due to the fact that when freezing is turned on, instead of using the constants from the graph module that was saved on cache miss, we have to take the constants from the AOTAutograd generated graph module. This PR does two things:
- It bypasses AOTAutogradCache when freezing is turned on. We should have always been doing this.
- It refactors the code to be way more clear about the constants we're using and when we're using them.
Basically, there are two possible sets of constants we can grab from the compiled fx graph.
1. If freezing is turned off, we save the constants directly in CompiledFxGraph.
2. If freezing is turned on, we save the *names* of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.
We implement two different classes for doing just this: one that has access to the post aotdispatch gm, which supports freezing, and one that doesn't have it, which does not support freezing. Then we construct the wrappers and unwrap the result as needed.
This makes it clear that the gm passed to AOTAutogradCache is *not* part of post compile, only the cache key generated from it is.
The whole flow is pretty confusing, but hopefully this gives us better types and static information for understanding what the different codepaths are doing.
Will add a specific AOTAutogradCache to confirm we bypass freezing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141897
Approved by: https://github.com/ezyang, https://github.com/masnesral
2024-12-04 16:44:29 +00:00
|
|
|
|
|
|
|
|
|
2024-11-27 14:04:14 +00:00
|
|
|
@dataclasses.dataclass
class CompiledFxGraph(OutputCode):
    """
    Class holding a compiled FX graph. This is the object serialized on disk
    to support FxGraph caching.
    """

    # The callable executed by __call__ (asserted non-None at call time).
    current_callable: Optional[Callable[..., Any]]
    # Cache key copied from the GraphLowering in __init__.
    cache_key: str
    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
    # Line mapping for the generated source, from the GraphLowering.
    cache_linemap: Optional[list[tuple[int, str]]]
    # Device/mutation metadata snapshotted from the GraphLowering.
    device_types: OrderedSet[str]
    device_idxs: OrderedSet[int]
    mutated_inputs: OrderedSet[str]
    mutated_input_idxs: OrderedSet[int]
    # Constant tensors saved directly in the cache entry. Under freezing,
    # frozen params are excluded here and tracked by name instead (below).
    constants: Optional[Dict[str, torch.Tensor]]
    # Maps a constant's allocated name in the lowered graph to the original
    # attribute name on the GraphModule (frozen params only; see __init__).
    frozen_param_names: Dict[str, str]
    torchbind_constants: Dict[str, torch._C.ScriptObject]
    # Per-output stride expressions (serialized as strings); None entries
    # for outputs without recorded strides.
    output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]]
    # If set, why cudagraphs were disabled for this graph.
    disabled_cudagraphs_reason: Optional[str]
    metrics_deltas: metrics.CachedMetricsDeltas
    counter_deltas: Counter[str]
    # This is a string representation of an expression we serialize
    # with the object so the guards can be evaluated in a different
    # context in order to verify the validity of serving a cached
    # fx graph. The expression must be generated by:
    # ShapeEnv.produce_guards_expression()
    guards_expr: Optional[str]

    # Info needed by cudagraph_post_compile; None when cudagraphs are off.
    cudagraph_info: Optional[CudagraphCachedInfo]
    fx_kwargs: _CompileFxKwargs
    inputs_to_check: Sequence[int]
    boxed_forward_device_index: Optional[BoxedDeviceIndex]

    # aot autograd needs to know to pass in inputs as a list (set in __init__)
    _boxed_call: Optional[bool] = None
    # Bundled Triton kernel artifacts, if attached via set_triton_bundle.
    _triton_bundle: Optional[list[TritonKernelArtifacts]] = None
|
2024-11-27 14:04:14 +00:00
|
|
|
|
|
|
|
|
    def __init__(
        self,
        current_callable: Optional[Callable[..., Any]],
        graph: GraphLowering,
        gm: torch.fx.GraphModule,
        output_strides: list[Optional[tuple[_StrideExprStr, ...]]],
        disabled_cudagraphs_reason: Optional[str],
        metrics_deltas: metrics.CachedMetricsDeltas,
        counter_deltas: Counter[str],
        cudagraphs: BoxedBool,
        example_inputs: Sequence[InputType],
        static_input_idxs: Sequence[int],
        fx_kwargs: _CompileFxKwargs,
        inputs_to_check: Sequence[int],
        boxed_forward_device_index: Optional[BoxedDeviceIndex],
    ) -> None:
        """
        Snapshot everything needed to cache and re-run this compiled graph.

        Reads the generated source from ``graph.cache_path``, copies device
        and mutation metadata off the GraphLowering, separates constants
        into directly-saved tensors vs. frozen-parameter names, and runs
        the up-front cudagraph eligibility checks, recording failure
        reasons in ``cudagraph_info``. May disable ``cudagraphs`` in place.
        """
        self.current_callable = current_callable
        self.cache_key = graph.cache_key
        if graph.cache_path:
            with open(graph.cache_path) as f:
                self.source_code = f.read()
        self.cache_linemap = graph.cache_linemap
        # TODO - ordered set
        self.device_types = OrderedSet(graph.device_types)
        self.device_idxs = OrderedSet(graph.device_idxs)
        self.mutated_inputs = OrderedSet(graph.mutated_inputs)
        self.mutated_input_idxs = OrderedSet(graph.mutated_input_idxs)

        # We store the constant attributes in the cache entry and re-attach them
        # to the module created in PyCodeCache.load_by_key_path. In the case that
        # the graph has frozen parameters, we save the mapping from the attribute
        # names in the GraphLowering to the original name of the attribute in the
        # GraphModule. When we create the module from the cache entry, we then
        # look up the constants from the current GraphModule. This scheme allows
        # us to support caching with freezing.
        if not has_frozen_params(gm):
            self.constants = graph.constants
            self.frozen_param_names = {}
        else:
            self.constants = {}
            self.frozen_param_names = {}
            for k, v in graph.constants.items():
                if is_frozen_param(v):
                    self.frozen_param_names[k] = graph.allocated_constant_name[k]
                else:
                    self.constants[k] = v

        self.torchbind_constants = graph.torchbind_constants
        self.output_strides = output_strides
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
        self.metrics_deltas = metrics_deltas
        self.counter_deltas = counter_deltas
        self.guards_expr = None
        self.cudagraph_info = None
        # Placeholder values; the real ones are assigned below after the
        # cudagraph checks have run.
        self.fx_kwargs = {}
        self.inputs_to_check = ()
        self.boxed_forward_device_index = None

        cudagraph_info = None
        if cudagraphs:
            # check cudagraph disabling reasons from inductor lowering
            if self.disabled_cudagraphs_reason:
                if "cuda" in self.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                complex_memory_overlap_inputs = any(
                    complex_memory_overlap(t)
                    for t in example_inputs
                    if isinstance(t, torch.Tensor)
                )

                if not config.triton.cudagraph_support_input_mutation:
                    # Skip supports for cudagraph-managed tensors
                    from torch._inductor.cudagraph_utils import (
                        check_for_mutation_ignore_cuda_graph_managed_tensor,
                    )

                    has_mutation_str = (
                        check_for_mutation_ignore_cuda_graph_managed_tensor(
                            gm,
                            self.mutated_inputs,
                            self.mutated_input_idxs,
                            static_input_idxs,
                        )
                    )
                    has_mutation = has_mutation_str is not None

                    if has_mutation:
                        self.disabled_cudagraphs_reason = has_mutation_str
                else:
                    # Check mutation later to support cudagraph-managed tensors
                    has_mutation = None

                # Each (passed, reason) pair; failed tests become skip reasons.
                cudagraph_tests = [
                    (not has_mutation, "mutated inputs"),
                    (not complex_memory_overlap_inputs, "complex memory overlap"),
                    (
                        all(
                            isinstance(t, (torch.Tensor, torch.SymInt))
                            for t in example_inputs
                        ),
                        "non-Tensor inputs",
                    ),
                ]
                output = output_node(gm)
                # output args are tuple of first argument
                assert len(output.args) == 1
                stack_traces = [
                    (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
                    for arg in output.args[0]
                ]
                cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
                placeholders = tuple(get_placeholder_info(gm.graph))
                cudagraph_info = CudagraphCachedInfo(
                    placeholders, stack_traces, cudagraph_fail_reasons
                )

        self.cudagraph_info = cudagraph_info
        self.inputs_to_check = inputs_to_check
        self.fx_kwargs = fx_kwargs
        # TODO: should this be part of fx_kwargs
        self.boxed_forward_device_index = boxed_forward_device_index

        # aot autograd needs to know to pass in inputs as a list
        self._boxed_call = True
|
|
|
|
|
|
2024-11-28 14:18:37 +00:00
|
|
|
    def __call__(self, inputs: Sequence[Any]) -> Any:
        """Run the compiled callable on ``inputs`` (passed as a boxed list;
        see ``_boxed_call``).

        Always finishes the runtime metrics context and ends the autotune
        cache bundle, even when the call raises.
        """
        assert self.current_callable is not None
        try:
            return self.current_callable(inputs)
        finally:
            get_runtime_metrics_context().finish()
            AutotuneCacheBundler.end_compile()
|
|
|
|
|
|
|
|
|
|
def post_compile(
    self,
    example_inputs: Sequence[InputType],
    cudagraphs: BoxedBool,
    constants: CompiledFxGraphConstants,
) -> None:
    """
    Post-processing performed right after this graph is obtained
    (on both cache hits and cache misses):

    - setting the tracing context output strides
    - running cudagraphs if enabled
    - realigning inputs

    The effects of this function are never stored back into the cache.
    """
    set_tracing_context_output_strides(example_inputs, self)

    if cudagraphs:
        if not self.disabled_cudagraphs_reason:
            cudagraph_post_compile(
                example_inputs,
                self,
                cudagraphs,
                constants.unwrap(self),
            )
        else:
            # Cudagraphs may have been disabled during the compilation
            # this cache entry was loaded from; propagate that decision
            # to the current process as well.
            if "cuda" in self.device_types:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}"
                )
            else:
                counters["inductor"]["cudagraph_skips"] += 1
            BoxedBool.disable(cudagraphs)

    # Cudagraphs could have been disabled just above, so inputs may still
    # need realignment even on that path.
    maybe_realign_inputs(
        cudagraphs,
        self,
        self.inputs_to_check,
    )
|
2024-11-28 14:18:37 +00:00
|
|
|
|
2024-11-28 14:18:38 +00:00
|
|
|
def set_triton_bundle(self, triton_bundle: Any) -> None:
    """Attach the Triton artifact bundle produced for this graph."""
    self._triton_bundle = triton_bundle
|
|
|
|
|
|
2024-12-03 00:26:07 +00:00
|
|
|
def prepare_for_serialization(self) -> None:
    """
    Drop fields that cannot be pickled before this object is cached.

    The compiled callable may be backed by C++/Triton/etc. and cannot be
    serialized directly; its PyCodeCache disk location is what gets
    persisted, and the callable is rebuilt from it on load.
    """
    # TODO: This could be better if we're ever able to serialize compiled
    # models to disk.
    self.current_callable = None
|
|
|
|
|
|
Refactor optional graph module into CompiledFxGraphConstants (#141897)
FXGraphCache supports freezing, but AOTAutogradCache does not. This is due to the fact that when freezing is turned on, instead of using the constants from the graph module that was saved on cache miss, we have to take the constants from the AOTAutograd generated graph module. This PR does two things:
- It bypasses AOTAutogradCache when freezing is turned on. We should have always been doing this.
- It refactors the code to be way more clear about the constants we're using and when we're using them.
Basically, there are two possible sets of constants we can grab from the compiled fx graph.
1. If freezing is turned off, we save the constants directly in CompiledFxGraph.
2. If freezing is turned on, we save the *names* of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.
We implement two different classes for doing just this: one that has access to the post aotdispatch gm, which supports freezing, and one that doesn't have it, which does not support freezing. Then we construct the wrappers and unwrap the result as needed.
This makes it clear that the gm passed to AOTAutogradCache is *not* part of post compile, only the cache key generated from it is.
The whole flow is pretty confusing, but hopefully this gives us better types and static information for understanding what the different codepaths are doing.
Will add a specific AOTAutogradCache to confirm we bypass freezing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141897
Approved by: https://github.com/ezyang, https://github.com/masnesral
2024-12-04 16:44:29 +00:00
|
|
|
def after_deserialization(self, constants: CompiledFxGraphConstants) -> str:
    """
    Rebuild the runtime callable after this object was loaded from cache.

    The callable itself is never stored in the cache entry (see
    prepare_for_serialization); instead the source code is re-materialized
    on disk if needed and loaded through PyCodeCache. Returns the on-disk
    path of the Python artifact.
    """
    from torch._dynamo.utils import counters, dynamo_timed
    from torch._inductor.codecache import (
        cpp_prefix_path,
        get_path,
        PyCodeCache,
        write_atomic,
    )

    # See _save_graph(); we don't store the callable in the cache entry so
    # recreate it here from the PyCodeCache disk cache.
    artifact_path = get_path(self.cache_key, "py")[2]
    code = self.source_code
    if not os.path.exists(artifact_path):
        counters["inductor"]["fxgraph_lookup_write_file"] += 1
        Path(os.path.dirname(artifact_path)).mkdir(parents=True, exist_ok=True)
        cpp_pp = cpp_prefix_path()
        if os.path.basename(cpp_pp) in code and cpp_pp not in code:
            # The source references the C++ prefix header by an old
            # directory name; rewrite the include to the current path.
            pattern = rf'#include\s*"[^"]+{os.path.basename(cpp_pp)}"'
            code = re.sub(pattern, f'#include "{cpp_pp}"', code)
            self.source_code = code
        write_atomic(artifact_path, code, make_dirs=True)

    try:
        with dynamo_timed(
            "PyCodeCache.load_by_key_path",
            log_pt2_compile_event=True,
        ):
            self.current_callable = PyCodeCache.load_by_key_path(
                self.cache_key,
                artifact_path,
                self.cache_linemap,
                constants.unwrap(self),
            ).call
    except OSError:
        log.error("Failed to load artifact: %s", artifact_path)
        raise

    return artifact_path
|
|
|
|
|
|
2024-11-27 14:04:14 +00:00
|
|
|
|
2024-11-28 14:18:38 +00:00
|
|
|
@dataclasses.dataclass
class CompiledAOTI(OutputCode):
    """
    Class holding an AOTInductor compiled so.
    """

    # Path to the compiled shared object, or a list of artifact paths when
    # the build produces more than one file.
    filename: Union[str, list[str]]

    def __call__(self, inputs: Sequence[Any]) -> Any:
        # Direct invocation is not implemented for AOTI artifacts.
        raise NotImplementedError("NYI")

    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        constants: CompiledFxGraphConstants,
    ) -> None:
        # No post-compile processing is needed for an AOTI artifact.
        pass

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        # AOTI artifacts do not carry a Triton bundle; nothing to record.
        pass
|
|
|
|
|
|
|
|
|
|
|
2024-12-04 02:06:09 +00:00
|
|
|
@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
    """
    OutputCode implementation that wraps a plain graph module and calls it
    directly, with no caching/cudagraph/bundle post-processing.
    """

    # The callable (presumably an fx GraphModule) invoked by __call__.
    gm: Any = None

    def __post_init__(self) -> None:
        # aot autograd needs to know to pass in inputs as a list
        self._boxed_call = True

    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        constants: CompiledFxGraphConstants,
    ) -> None:
        # No post-compile steps for the mock output.
        pass

    def __call__(self, inputs: Sequence[Any]) -> Any:
        # Boxed calling convention: the whole input list is passed through.
        return self.gm(inputs)

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        # Mock output has no Triton bundle to record.
        pass
|