[inductor] Minor compile time optimizations in DefaultHandler (#146282)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146282
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255, #146257
This commit is contained in:
Jason Ansel 2025-02-07 13:32:55 -08:00 committed by PyTorch MergeBot
parent 06604c4ec1
commit d35f6b2339
5 changed files with 42 additions and 7 deletions

View file

View file

@ -948,7 +948,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
) )
def output(self, x0: OpVarT) -> None: def output(self, *args: OpVarT) -> None:
raise AssertionError( raise AssertionError(
f"{type(self).__name__}: ops.output should not appear at codegen time" f"{type(self).__name__}: ops.output should not appear at codegen time"
) )

View file

@ -368,7 +368,7 @@ class DtypePropagationOpsHandler:
) -> None: ) -> None:
return None return None
def output(self, x: DTypeArg) -> None: def output(self, *args: DTypeArg) -> None:
raise AssertionError( raise AssertionError(
f"{type(self).__name__}: ops.output should not appear here" f"{type(self).__name__}: ops.output should not appear here"
) )

View file

@ -602,8 +602,8 @@ class LoopBodyBlock:
return var return var
@staticmethod @staticmethod
def output(result): def output(*result):
tracer.create_proxy("output", "output", (result,), {}) tracer.create_proxy("output", "output", result, {})
tracer = LightTracer() tracer = LightTracer()
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

View file

@ -1,9 +1,11 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations from __future__ import annotations
import inspect
import itertools import itertools
import re import re
import warnings import warnings
from io import StringIO
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from unittest.mock import patch from unittest.mock import patch
@ -729,7 +731,7 @@ class OpsHandler(Generic[T]):
) -> T: ) -> T:
raise NotImplementedError raise NotImplementedError
def output(self, x0: T) -> None: def output(self, *args: T) -> None:
"""This is a fake op used in analysis but not codegen""" """This is a fake op used in analysis but not codegen"""
raise NotImplementedError raise NotImplementedError
@ -755,7 +757,7 @@ class DefaultHandler(OpsHandler[Any]):
provide generic op behavior. provide generic op behavior.
Args: Args:
target: name of the op, see OpHandler.target name: name of the op, see OpHandler.{name}
args: positional args passed to the op args: positional args passed to the op
kwargs: keyword args passed to the op kwargs: keyword args passed to the op
@ -783,8 +785,41 @@ class DefaultHandler(OpsHandler[Any]):
@classmethod @classmethod
def _init_cls(cls): def _init_cls(cls):
"""
Here we codegen many functions of the form:
def add(self, a, b):
return self._default('add', (a, b), {})
and install them in cls. This is the same as _call_default above,
but is about 1.2x faster since CPython varargs parsing is slow.
"""
code = StringIO()
for target in OP_NAMES: for target in OP_NAMES:
setattr(cls, target, cls._call_default(target)) sig = inspect.signature(getattr(OpsHandler, target))
if all(
p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and p.default is inspect.Parameter.empty
for p in sig.parameters.values()
):
self_arg, *args = sig.parameters.keys()
assert self_arg == "self"
code.write(
f"""
def {target}(self, {', '.join(args)}):
return self._default({target!r}, ({', '.join(args)}, ), {{}})
""".strip()
)
code.write("\n\n")
else:
# slower fallback for ops with default or variadic arguments
setattr(cls, target, cls._call_default(target))
ctx: dict[str, Any] = {}
exec(code.getvalue(), ctx)
for target, impl in ctx.items():
if target in OP_NAMES:
setattr(cls, target, impl)
DefaultHandler._init_cls() DefaultHandler._init_cls()