From 145fd5bad0cd16141fc0004d96c6c52f9759e09f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 19 Dec 2024 23:22:43 +0000 Subject: [PATCH] Revert "[Dynamo] only import einops if version is lower than 0.7.0 (#142847)" This reverts commit a96387a481633389a6b5a5ac7b8406e9216f320e. Reverted https://github.com/pytorch/pytorch/pull/142847 on behalf of https://github.com/huydhn due to This has been reverted internally D67436053 ([comment](https://github.com/pytorch/pytorch/pull/142847#issuecomment-2555942351)) --- torch/_dynamo/decorators.py | 46 +++++++++++++++---------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 2a4e004a6ca..893d4e139f1 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -2,12 +2,10 @@ # ruff: noqa: TCH004 import functools import inspect -import sys from dataclasses import dataclass from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar import torch -from torch._vendor.packaging.version import Version from torch.utils._contextlib import _DecoratorContextManager from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -610,33 +608,27 @@ def mark_static_address(t, guard=True): # Note: this carefully avoids eagerly import einops. # TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2 def _allow_in_graph_einops(): - mod = sys.modules.get("einops") - if mod is None: - return - else: - # version > 0.7.0 does allow_in_graph out of tree - if Version(mod.__version__) < Version("0.7.0"): - import einops + import einops - try: - # requires einops > 0.6.1, torch >= 2.0 - from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401 - _ops_were_registered_in_torchdynamo, - ) + try: + # requires einops > 0.6.1, torch >= 2.0 + from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401 + _ops_were_registered_in_torchdynamo, + ) - # einops > 0.6.1 will call the op registration logic as it is imported. - except ImportError: - # einops <= 0.6.1 - allow_in_graph(einops.rearrange) - allow_in_graph(einops.reduce) - if hasattr(einops, "repeat"): - allow_in_graph(einops.repeat) # available since einops 0.2.0 - if hasattr(einops, "einsum"): - allow_in_graph(einops.einsum) # available since einops 0.5.0 - if hasattr(einops, "pack"): - allow_in_graph(einops.pack) # available since einops 0.6.0 - if hasattr(einops, "unpack"): - allow_in_graph(einops.unpack) # available since einops 0.6.0 + # einops > 0.6.1 will call the op registration logic as it is imported. + except ImportError: + # einops <= 0.6.1 + allow_in_graph(einops.rearrange) + allow_in_graph(einops.reduce) + if hasattr(einops, "repeat"): + allow_in_graph(einops.repeat) # available since einops 0.2.0 + if hasattr(einops, "einsum"): + allow_in_graph(einops.einsum) # available since einops 0.5.0 + if hasattr(einops, "pack"): + allow_in_graph(einops.pack) # available since einops 0.6.0 + if hasattr(einops, "unpack"): + allow_in_graph(einops.unpack) # available since einops 0.6.0 trace_rules.add_module_init_func("einops", _allow_in_graph_einops)