[inductor] Minor compile time optimizations in DefaultHandler (#146282)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146282
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255, #146257
This commit is contained in:
Jason Ansel 2025-02-07 13:32:55 -08:00 committed by PyTorch MergeBot
parent 06604c4ec1
commit d35f6b2339
5 changed files with 42 additions and 7 deletions

View file

View file

@ -948,7 +948,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
)
def output(self, x0: OpVarT) -> None:
def output(self, *args: OpVarT) -> None:
raise AssertionError(
f"{type(self).__name__}: ops.output should not appear at codegen time"
)

View file

@ -368,7 +368,7 @@ class DtypePropagationOpsHandler:
) -> None:
return None
def output(self, x: DTypeArg) -> None:
def output(self, *args: DTypeArg) -> None:
raise AssertionError(
f"{type(self).__name__}: ops.output should not appear here"
)

View file

@ -602,8 +602,8 @@ class LoopBodyBlock:
return var
@staticmethod
def output(result):
tracer.create_proxy("output", "output", (result,), {})
def output(*result):
tracer.create_proxy("output", "output", result, {})
tracer = LightTracer()
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

View file

@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import inspect
import itertools
import re
import warnings
from io import StringIO
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from unittest.mock import patch
@ -729,7 +731,7 @@ class OpsHandler(Generic[T]):
) -> T:
raise NotImplementedError
def output(self, x0: T) -> None:
def output(self, *args: T) -> None:
"""This is a fake op used in analysis but not codegen"""
raise NotImplementedError
@ -755,7 +757,7 @@ class DefaultHandler(OpsHandler[Any]):
provide generic op behavior.
Args:
target: name of the op, see OpHandler.target
name: name of the op, see OpsHandler.{name}
args: positional args passed to the op
kwargs: keyword args passed to the op
@ -783,8 +785,41 @@ class DefaultHandler(OpsHandler[Any]):
@classmethod
def _init_cls(cls):
    """
    Install a forwarding method on ``cls`` for every name in OP_NAMES.

    For ops whose OpsHandler signature consists only of required
    positional-or-keyword parameters, we codegen functions of the form:

        def add(self, a, b):
            return self._default('add', (a, b, ), {})

    compile them all with a single exec(), and install them in cls.  This
    is the same as what cls._call_default(target) provides, but is about
    1.2x faster since CPython varargs parsing is slow.

    Ops with default or variadic arguments keep the slower
    cls._call_default(target) wrapper.
    """
    code = StringIO()
    for target in OP_NAMES:
        sig = inspect.signature(getattr(OpsHandler, target))
        if all(
            p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
            and p.default is inspect.Parameter.empty
            for p in sig.parameters.values()
        ):
            self_arg, *args = sig.parameters.keys()
            assert self_arg == "self"
            # Emit a trailing comma per element so the tuple literal is
            # valid for any arity: "()" for none, "(a, )" for one, etc.
            params = "".join(f", {arg}" for arg in args)
            packed = "".join(f"{arg}, " for arg in args)
            code.write(
                f"def {target}(self{params}):\n"
                f"    return self._default({target!r}, ({packed}), {{}})\n\n"
            )
        else:
            # slower fallback for ops with default or variadic arguments
            setattr(cls, target, cls._call_default(target))

    ctx: dict[str, Any] = {}
    exec(code.getvalue(), ctx)
    # exec() also populates ctx with "__builtins__"; only install real ops.
    for target, impl in ctx.items():
        if target in OP_NAMES:
            setattr(cls, target, impl)
DefaultHandler._init_cls()