mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: The current implementation introduces a compile-time regression due to overhead hashing large constants. To support freezing+caching, we consider only the tensor metadata of frozen params, but we neglect to do the same for any constants created as a result of folding frozen params. This PR explicitly marks the constants created during freezing (and constant folding during freezing) and uses that info in the inductor cache to determine when to hash a tensor value+metadata vs. metadata only. Test Plan: `python benchmarks/dynamo/torchbench.py --backend inductor --device cuda --only alexnet --bfloat16 --cold-start-latency --print-compilation-time --inference --performance --freezing` Pull Request resolved: https://github.com/pytorch/pytorch/pull/145868 Approved by: https://github.com/eellison
54 lines
1.2 KiB
Python
54 lines
1.2 KiB
Python
import contextlib
|
|
import threading
|
|
from typing import Any, Generator
|
|
|
|
import torch
|
|
|
|
|
|
_TLS = threading.local()
|
|
|
|
|
|
def _freezing_active() -> bool:
|
|
return getattr(_TLS, "freezing_active", False)
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def enter_freezing() -> Generator[Any, None, None]:
|
|
"""
|
|
Context manager to designate when freezing is active.
|
|
"""
|
|
prev = _freezing_active()
|
|
_TLS.freezing_active = True
|
|
try:
|
|
yield
|
|
finally:
|
|
_TLS.freezing_active = prev
|
|
|
|
|
|
def record_has_frozen_params(gm: torch.fx.GraphModule) -> None:
|
|
"""
|
|
Mark the gm as having frozen params.
|
|
"""
|
|
gm._has_frozen_params = True # type: ignore[assignment]
|
|
|
|
|
|
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
|
|
"""
|
|
Return True if the gm has frozen parameters.
|
|
"""
|
|
return getattr(gm, "_has_frozen_params", False)
|
|
|
|
|
|
def maybe_set_is_frozen_param(t: torch.Tensor) -> None:
|
|
"""
|
|
Mark the provided tensor as a frozen param if freezing is active.
|
|
"""
|
|
if _freezing_active():
|
|
t._is_frozen_param = True # type: ignore[attr-defined]
|
|
|
|
|
|
def is_frozen_param(t: torch.Tensor) -> bool:
|
|
"""
|
|
Return True if the tensor is a frozen param.
|
|
"""
|
|
return getattr(t, "_is_frozen_param", False)
|