mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Minor compile time optimizations in DefaultHandler (#146282)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146282
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255, #146257
This commit is contained in:
parent
06604c4ec1
commit
d35f6b2339
5 changed files with 42 additions and 7 deletions
0
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh
Normal file → Executable file
0
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh
Normal file → Executable file
|
|
@ -948,7 +948,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
|
|||
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
|
||||
)
|
||||
|
||||
def output(self, x0: OpVarT) -> None:
|
||||
def output(self, *args: OpVarT) -> None:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.output should not appear at codegen time"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -368,7 +368,7 @@ class DtypePropagationOpsHandler:
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
def output(self, x: DTypeArg) -> None:
|
||||
def output(self, *args: DTypeArg) -> None:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.output should not appear here"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -602,8 +602,8 @@ class LoopBodyBlock:
|
|||
return var
|
||||
|
||||
@staticmethod
|
||||
def output(result):
|
||||
tracer.create_proxy("output", "output", (result,), {})
|
||||
def output(*result):
|
||||
tracer.create_proxy("output", "output", result, {})
|
||||
|
||||
tracer = LightTracer()
|
||||
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import itertools
|
||||
import re
|
||||
import warnings
|
||||
from io import StringIO
|
||||
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -729,7 +731,7 @@ class OpsHandler(Generic[T]):
|
|||
) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
def output(self, x0: T) -> None:
|
||||
def output(self, *args: T) -> None:
|
||||
"""This is a fake op used in analysis but not codegen"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -755,7 +757,7 @@ class DefaultHandler(OpsHandler[Any]):
|
|||
provide generic op behavior.
|
||||
|
||||
Args:
|
||||
target: name of the op, see OpHandler.target
|
||||
name: name of the op, see OpHandler.{name}
|
||||
args: positional args passed to the op
|
||||
kwargs: keyword args passed to the op
|
||||
|
||||
|
|
@ -783,8 +785,41 @@ class DefaultHandler(OpsHandler[Any]):
|
|||
|
||||
@classmethod
|
||||
def _init_cls(cls):
|
||||
"""
|
||||
Here we codegen many functions of the form:
|
||||
|
||||
def add(self, a, b):
|
||||
return self._default('add', (a, b), {})
|
||||
|
||||
and install them in cls. This is the same as _call_default above,
|
||||
but is about 1.2x faster since CPython varargs parsing is slow.
|
||||
"""
|
||||
code = StringIO()
|
||||
for target in OP_NAMES:
|
||||
setattr(cls, target, cls._call_default(target))
|
||||
sig = inspect.signature(getattr(OpsHandler, target))
|
||||
if all(
|
||||
p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
and p.default is inspect.Parameter.empty
|
||||
for p in sig.parameters.values()
|
||||
):
|
||||
self_arg, *args = sig.parameters.keys()
|
||||
assert self_arg == "self"
|
||||
code.write(
|
||||
f"""
|
||||
def {target}(self, {', '.join(args)}):
|
||||
return self._default({target!r}, ({', '.join(args)}, ), {{}})
|
||||
""".strip()
|
||||
)
|
||||
code.write("\n\n")
|
||||
else:
|
||||
# slower fallback for ops with default or variadic arguments
|
||||
setattr(cls, target, cls._call_default(target))
|
||||
|
||||
ctx: dict[str, Any] = {}
|
||||
exec(code.getvalue(), ctx)
|
||||
for target, impl in ctx.items():
|
||||
if target in OP_NAMES:
|
||||
setattr(cls, target, impl)
|
||||
|
||||
|
||||
DefaultHandler._init_cls()
|
||||
|
|
|
|||
Loading…
Reference in a new issue