# pytorch/torch/_dynamo/variables/__init__.py
# 2025-02-08 22:42:12 +00:00
#
# 198 lines
# 5.2 KiB
# Python

from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
CatchWarningsCtxManagerVariable,
ContextWrappingVariable,
CUDADeviceVariable,
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
FSDPParamGroupUseTrainingStateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
TemporarilyPopInterpreterStackCtxManagerVariable,
VmapIncrementNestingCtxManagerVariable,
WithExitFunctionVariable,
)
from .dicts import (
ConstDictVariable,
DefaultDictVariable,
DictKeySetVariable,
FrozensetVariable,
MappingProxyVariable,
NNModuleHooksDictVariable,
SetVariable,
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
BuiltinMethodVariable,
CollectionsNamedTupleFunction,
CreateTMADescriptorVariable,
FunctionDecoratedByContextlibContextManagerVariable,
FunctoolsPartialVariable,
FunctoolsWrapsVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
TMADescriptorVariable,
UserFunctionVariable,
UserMethodVariable,
)
from .higher_order_ops import (
FunctionalCallVariable,
FunctorchHigherOrderVariable,
TorchHigherOrderOperatorVariable,
)
from .iter import (
CountIteratorVariable,
CycleIteratorVariable,
FilterVariable,
IteratorVariable,
ItertoolsVariable,
MapVariable,
RepeatIteratorVariable,
ZipVariable,
)
from .lazy import LazyVariableTracker
from .lists import (
BaseListVariable,
FxImmutableListVariable,
ListIteratorVariable,
ListVariable,
NamedTupleVariable,
RangeVariable,
RestrictedListSubclassVariable,
SliceVariable,
TupleIteratorVariable,
TupleVariable,
)
from .misc import (
AutogradFunctionContextVariable,
AutogradFunctionVariable,
CellVariable,
DeletedVariable,
ExceptionVariable,
GetAttrVariable,
LambdaVariable,
MethodWrapperVariable,
NewGlobalVariable,
NumpyVariable,
PythonModuleVariable,
RandomClassVariable,
RandomVariable,
RegexPatternVariable,
StringFormatVariable,
SuperVariable,
TorchVersionVariable,
TypingVariable,
UnknownVariable,
WeakRefVariable,
)
from .nn_module import (
FSDPManagedNNModuleVariable,
NNModuleVariable,
UnspecializedBuiltinNNModuleVariable,
UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
DataPtrVariable,
FakeItemVariable,
NumpyNdarrayVariable,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
UntypedStorageVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .user_defined import (
MutableMappingVariable,
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedObjectVariable,
UserDefinedTupleVariable,
)
# Public interface of this package: every VariableTracker subclass (and
# related helper) re-exported from the submodules imported above.
# Keep this list in sync with those imports -- any name imported here for
# re-export must also be listed, otherwise `from ... import *` drops it and
# type checkers that honour `__all__` treat it as a private re-export.
# Entries are kept in case-insensitive alphabetical order.
__all__ = [
    "AutogradFunctionContextVariable",
    "AutogradFunctionVariable",
    "BackwardHookVariable",
    "BaseListVariable",
    "BuiltinMethodVariable",
    "BuiltinVariable",
    "CatchWarningsCtxManagerVariable",
    "CellVariable",
    "CollectionsNamedTupleFunction",
    "ConstantVariable",
    "ConstDictVariable",
    "ContextWrappingVariable",
    "CountIteratorVariable",
    "CreateTMADescriptorVariable",
    "CUDADeviceVariable",
    "CycleIteratorVariable",
    "DataPtrVariable",
    "DefaultDictVariable",
    "DeletedVariable",
    "DeterministicAlgorithmsVariable",
    "DictKeySetVariable",
    "DisabledSavedTensorsHooksVariable",
    "DistributedVariable",
    "DualLevelContextManager",
    "EnumVariable",
    "ExceptionVariable",
    "FakeItemVariable",
    "FilterVariable",
    "FrozensetVariable",
    "FSDPManagedNNModuleVariable",
    "FSDPParamGroupUseTrainingStateVariable",
    "FunctionalCallVariable",
    "FunctionDecoratedByContextlibContextManagerVariable",
    "FunctoolsPartialVariable",
    "FunctoolsWrapsVariable",
    "FunctorchHigherOrderVariable",
    "FxImmutableListVariable",
    "GetAttrVariable",
    "GradIncrementNestingCtxManagerVariable",
    "GradInplaceRequiresGradCtxManagerVariable",
    "GradModeVariable",
    "InferenceModeVariable",
    "IteratorVariable",
    "ItertoolsVariable",
    "JvpIncrementNestingCtxManagerVariable",
    "LambdaVariable",
    "LazyVariableTracker",
    "ListIteratorVariable",
    "ListVariable",
    "LocalGeneratorFunctionVariable",
    "LocalGeneratorObjectVariable",
    "MappingProxyVariable",
    "MapVariable",
    "MethodWrapperVariable",
    "MutableMappingVariable",
    "NamedTupleVariable",
    "NestedUserFunctionVariable",
    "NewGlobalVariable",
    "NNModuleHooksDictVariable",
    "NNModuleVariable",
    "NumpyNdarrayVariable",
    "NumpyVariable",
    "OptimizerVariable",
    "PlacementVariable",
    "PolyfilledFunctionVariable",
    "PythonModuleVariable",
    "RandomClassVariable",
    "RandomVariable",
    "RangeVariable",
    "RegexPatternVariable",
    "RemovableHandleVariable",
    "RepeatIteratorVariable",
    "RestrictedListSubclassVariable",
    "SDPAKernelVariable",
    "SDPAParamsVariable",
    "SetFwdGradEnabledContextManager",
    "SetVariable",
    "SkipFunctionVariable",
    "SliceVariable",
    "StreamContextVariable",
    "StreamVariable",
    "StringFormatVariable",
    "SuperVariable",
    "SymNodeVariable",
    "TemporarilyPopInterpreterStackCtxManagerVariable",
    "TensorVariable",
    "TMADescriptorVariable",
    "TorchCtxManagerClassVariable",
    "TorchHigherOrderOperatorVariable",
    "TorchInGraphFunctionVariable",
    "TorchVersionVariable",
    "TupleIteratorVariable",
    "TupleVariable",
    "TypingVariable",
    "UnknownVariable",
    "UnspecializedBuiltinNNModuleVariable",
    "UnspecializedNNModuleVariable",
    "UnspecializedPythonVariable",
    "UntypedStorageVariable",
    "UserDefinedClassVariable",
    "UserDefinedDictVariable",
    "UserDefinedObjectVariable",
    "UserDefinedTupleVariable",
    "UserFunctionVariable",
    "UserMethodVariable",
    "VariableTracker",
    "VmapIncrementNestingCtxManagerVariable",
    "WeakRefVariable",
    "WithExitFunctionVariable",
    "ZipVariable",
]