pytorch/torch/_inductor/freezing_utils.py
Sam Larsen 2811f33d12 Fix code cache + freezing compile-time regression (#145868)
Summary: The current implementation introduces a compile-time regression due to the overhead of hashing large constants. To support freezing+caching, we consider only the tensor metadata of frozen params, but we neglect to do the same for any constants created as a result of folding frozen params. This PR explicitly marks the constants created during freezing (and constant folding during freezing) and uses that info in the inductor cache to determine when to hash a tensor value+metadata vs. metadata only.

Test Plan: `python benchmarks/dynamo/torchbench.py --backend inductor --device cuda --only alexnet --bfloat16 --cold-start-latency --print-compilation-time --inference --performance --freezing`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145868
Approved by: https://github.com/eellison
2025-01-31 02:04:15 +00:00

54 lines
1.2 KiB
Python

import contextlib
import threading
from typing import Any, Generator
import torch
# Per-thread storage for the "freezing is active" flag. The attribute
# ``freezing_active`` is created lazily the first time it is assigned.
_TLS = threading.local()


def _freezing_active() -> bool:
    """
    Return the current thread's ``freezing_active`` flag.

    The flag lives on the thread-local ``_TLS`` object; when it has not been
    assigned in the calling thread, ``False`` is returned.
    """
    return getattr(_TLS, "freezing_active", False)
@contextlib.contextmanager
def enter_freezing() -> Generator[Any, None, None]:
    """
    Context manager to designate when freezing is active.

    On entry the thread-local ``freezing_active`` flag is set to ``True``.
    On exit the flag is put back to the value it had on entry, so nested
    uses leave the outer state intact. The restore happens in a ``finally``
    clause and therefore also runs when the managed body raises.
    """
    # Remember the value seen on entry so nesting restores it correctly.
    prev = _freezing_active()
    _TLS.freezing_active = True
    try:
        yield
    finally:
        _TLS.freezing_active = prev
def record_has_frozen_params(gm: torch.fx.GraphModule) -> None:
    """
    Mark the gm as having frozen params.

    The mark is stored by setting the ``_has_frozen_params`` attribute on
    ``gm`` to ``True``; ``gm`` is mutated in place and nothing is returned.
    """
    gm._has_frozen_params = True  # type: ignore[assignment]
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
    """
    Return True if the gm has frozen parameters.

    Reads the ``_has_frozen_params`` attribute of ``gm``; when the attribute
    is absent the result is ``False``.
    """
    return getattr(gm, "_has_frozen_params", False)
def maybe_set_is_frozen_param(t: torch.Tensor) -> None:
    """
    Mark the provided tensor as a frozen param if freezing is active.

    When ``_freezing_active()`` is true for the calling thread, the
    ``_is_frozen_param`` attribute of ``t`` is set to ``True``. Otherwise the
    tensor is left untouched. Nothing is returned.
    """
    if _freezing_active():
        t._is_frozen_param = True  # type: ignore[attr-defined]
def is_frozen_param(t: torch.Tensor) -> bool:
    """
    Return True if the tensor is a frozen param.

    Reads the ``_is_frozen_param`` attribute of ``t``; when the attribute is
    absent the result is ``False``.
    """
    return getattr(t, "_is_frozen_param", False)