[PT2][Runtime Numeric Check] Fix compatibility issue (#118578)

Summary: Titled

Test Plan: WIP

Differential Revision: D53196722

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118578
Approved by: https://github.com/jackiexu1992
This commit is contained in:
Menglu Yu 2024-01-30 08:04:27 +00:00 committed by PyTorch MergeBot
parent b7c8485704
commit 3ecc2f3a0d

View file

@ -69,7 +69,9 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
if config.pattern_matcher:
lazy_init()
if config.fx_passes_numeric_check.get("pre_grad", False):
if hasattr(
config, "fx_passes_numeric_check"
) and config.fx_passes_numeric_check.get("pre_grad", False):
gm_before_fx_passes = gm.__copy__()
# explicitly run with predispatch atenIR based passes
if config.is_predispatch:
@ -94,7 +96,11 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
gm.graph.lint()
gm.recompile()
if config.pattern_matcher and config.fx_passes_numeric_check.get("pre_grad", False):
if (
config.pattern_matcher
and hasattr(config, "fx_passes_numeric_check")
and config.fx_passes_numeric_check.get("pre_grad", False)
):
from .numeric_utils import numeric_check_if_enabled
gm_after_fx_passes = gm.__copy__()