Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00.
Add decorator `torch.compiler.substitute_in_graph` to register polyfill for unsupported C++ function to avoid graph break. This API provides an official way to add support for dynamo for third-party C extensions. Also, it can be used to simplify our implementation for `torch._dynamo.polyfill`.
5ee070266f/torch/_dynamo/variables/builtin.py (L97-L107)
Example:
```python
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
Unsupported: ...
>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(sequence, x):
... for i, item in enumerate(sequence):
... if item is x or item == x:
... return i
... raise ValueError("sequence.index(x): x not in sequence")
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133712
Approved by: https://github.com/jansel
175 lines · 4.5 KiB · Python
"""Public re-export surface for Dynamo's ``VariableTracker`` hierarchy.

This module gathers the concrete ``VariableTracker`` subclasses from the
sibling submodules (``base``, ``builtin``, ``ctx_manager``, ``dicts``, …)
into a single namespace, and declares the public subset via ``__all__``.
"""

from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
    CatchWarningsCtxManagerVariable,
    ContextWrappingVariable,
    CUDADeviceVariable,
    DeterministicAlgorithmsVariable,
    DisabledSavedTensorsHooksVariable,
    DualLevelContextManager,
    FSDPParamGroupUseTrainingStateVariable,
    GradIncrementNestingCtxManagerVariable,
    GradInplaceRequiresGradCtxManagerVariable,
    GradModeVariable,
    InferenceModeVariable,
    JvpIncrementNestingCtxManagerVariable,
    SetFwdGradEnabledContextManager,
    StreamContextVariable,
    StreamVariable,
    VmapIncrementNestingCtxManagerVariable,
    WithExitFunctionVariable,
)
from .dicts import (
    ConstDictVariable,
    CustomizedDictVariable,
    DefaultDictVariable,
    SetVariable,
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
    FunctoolsPartialVariable,
    NestedUserFunctionVariable,
    PolyfilledFunctionVariable,
    SkipFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .higher_order_ops import (
    FunctionalCallVariable,
    FunctorchHigherOrderVariable,
    TorchHigherOrderOperatorVariable,
)
from .iter import (
    CountIteratorVariable,
    CycleIteratorVariable,
    IteratorVariable,
    ItertoolsVariable,
    RepeatIteratorVariable,
)
from .lazy import LazyVariableTracker
from .lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    RestrictedListSubclassVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ClosureVariable,
    DeletedVariable,
    ExceptionVariable,
    GetAttrVariable,
    InspectSignatureVariable,
    LambdaVariable,
    MethodWrapperVariable,
    NewCellVariable,
    NewGlobalVariable,
    NumpyVariable,
    PythonModuleVariable,
    RandomClassVariable,
    RandomVariable,
    RegexPatternVariable,
    StringFormatVariable,
    SuperVariable,
    TorchVersionVariable,
    TypingVariable,
    UnknownVariable,
)
from .nn_module import (
    NNModuleVariable,
    UnspecializedBuiltinNNModuleVariable,
    UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
    FakeItemVariable,
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
    UntypedStorageVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .user_defined import (
    MutableMappingVariable,
    RemovableHandleVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
    WeakRefVariable,
)


# Public API, sorted case-insensitively.
# NOTE(review): several imported names (e.g. SetVariable, SymNodeVariable,
# ExceptionVariable, DistributedVariable, the higher-order-op variables) are
# deliberately-looking omissions from __all__ — confirm intent before adding
# them, since including a name here changes what `import *` re-exports.
__all__ = [
    "AutogradFunctionContextVariable",
    "AutogradFunctionVariable",
    "BackwardHookVariable",
    "BaseListVariable",
    "BuiltinVariable",
    "CatchWarningsCtxManagerVariable",
    "ClosureVariable",
    "ConstantVariable",
    "ConstDictVariable",
    "ContextWrappingVariable",
    "CountIteratorVariable",
    "CUDADeviceVariable",
    "CustomizedDictVariable",
    "CycleIteratorVariable",
    "DefaultDictVariable",
    "DeletedVariable",
    "DeterministicAlgorithmsVariable",
    "EnumVariable",
    "FakeItemVariable",
    "GetAttrVariable",
    "GradModeVariable",
    "InspectSignatureVariable",
    "IteratorVariable",
    "ItertoolsVariable",
    "LambdaVariable",
    "LazyVariableTracker",
    "ListIteratorVariable",
    "ListVariable",
    "NamedTupleVariable",
    "NestedUserFunctionVariable",
    "NewCellVariable",
    "NewGlobalVariable",
    "NNModuleVariable",
    "NumpyNdarrayVariable",
    "NumpyVariable",
    "OptimizerVariable",
    "PlacementVariable",
    "PolyfilledFunctionVariable",
    "PythonModuleVariable",
    "RangeVariable",
    "RegexPatternVariable",
    "RemovableHandleVariable",
    "RepeatIteratorVariable",
    "RestrictedListSubclassVariable",
    "SDPAParamsVariable",
    "SkipFunctionVariable",
    "SliceVariable",
    "StringFormatVariable",
    "SuperVariable",
    "TensorVariable",
    "TorchCtxManagerClassVariable",
    "TorchInGraphFunctionVariable",
    "TorchVersionVariable",
    "TupleVariable",
    "UnknownVariable",
    "UnspecializedNNModuleVariable",
    "UnspecializedPythonVariable",
    "UntypedStorageVariable",
    "UserDefinedClassVariable",
    "UserDefinedObjectVariable",
    "UserFunctionVariable",
    "UserMethodVariable",
    "VariableTracker",
    "WithExitFunctionVariable",
]
|