mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Minor compile time optimizations in DefaultHandler (#146282)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146282
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255, #146257
This commit is contained in:
parent
06604c4ec1
commit
d35f6b2339
5 changed files with 42 additions and 7 deletions
0
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh
Normal file → Executable file
0
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh
Normal file → Executable file
|
|
@ -948,7 +948,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
|
||||||
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
|
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
|
||||||
)
|
)
|
||||||
|
|
||||||
def output(self, x0: OpVarT) -> None:
|
def output(self, *args: OpVarT) -> None:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"{type(self).__name__}: ops.output should not appear at codegen time"
|
f"{type(self).__name__}: ops.output should not appear at codegen time"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -368,7 +368,7 @@ class DtypePropagationOpsHandler:
|
||||||
) -> None:
|
) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def output(self, x: DTypeArg) -> None:
|
def output(self, *args: DTypeArg) -> None:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"{type(self).__name__}: ops.output should not appear here"
|
f"{type(self).__name__}: ops.output should not appear here"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -602,8 +602,8 @@ class LoopBodyBlock:
|
||||||
return var
|
return var
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def output(result):
|
def output(*result):
|
||||||
tracer.create_proxy("output", "output", (result,), {})
|
tracer.create_proxy("output", "output", result, {})
|
||||||
|
|
||||||
tracer = LightTracer()
|
tracer = LightTracer()
|
||||||
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
|
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
from io import StringIO
|
||||||
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
@ -729,7 +731,7 @@ class OpsHandler(Generic[T]):
|
||||||
) -> T:
|
) -> T:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def output(self, x0: T) -> None:
|
def output(self, *args: T) -> None:
|
||||||
"""This is a fake op used in analysis but not codegen"""
|
"""This is a fake op used in analysis but not codegen"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -755,7 +757,7 @@ class DefaultHandler(OpsHandler[Any]):
|
||||||
provide generic op behavior.
|
provide generic op behavior.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target: name of the op, see OpHandler.target
|
name: name of the op, see OpHandler.{name}
|
||||||
args: positional args passed to the op
|
args: positional args passed to the op
|
||||||
kwargs: keyword args passed to the op
|
kwargs: keyword args passed to the op
|
||||||
|
|
||||||
|
|
@ -783,8 +785,41 @@ class DefaultHandler(OpsHandler[Any]):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _init_cls(cls):
|
def _init_cls(cls):
|
||||||
|
"""
|
||||||
|
Here we codegen many functions of the form:
|
||||||
|
|
||||||
|
def add(self, a, b):
|
||||||
|
return self._default('add', (a, b), {})
|
||||||
|
|
||||||
|
and install them in cls. This is the same as _call_default above,
|
||||||
|
but is about 1.2x faster since CPython varargs parsing is slow.
|
||||||
|
"""
|
||||||
|
code = StringIO()
|
||||||
for target in OP_NAMES:
|
for target in OP_NAMES:
|
||||||
setattr(cls, target, cls._call_default(target))
|
sig = inspect.signature(getattr(OpsHandler, target))
|
||||||
|
if all(
|
||||||
|
p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
and p.default is inspect.Parameter.empty
|
||||||
|
for p in sig.parameters.values()
|
||||||
|
):
|
||||||
|
self_arg, *args = sig.parameters.keys()
|
||||||
|
assert self_arg == "self"
|
||||||
|
code.write(
|
||||||
|
f"""
|
||||||
|
def {target}(self, {', '.join(args)}):
|
||||||
|
return self._default({target!r}, ({', '.join(args)}, ), {{}})
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
code.write("\n\n")
|
||||||
|
else:
|
||||||
|
# slower fallback for ops with default or variadic arguments
|
||||||
|
setattr(cls, target, cls._call_default(target))
|
||||||
|
|
||||||
|
ctx: dict[str, Any] = {}
|
||||||
|
exec(code.getvalue(), ctx)
|
||||||
|
for target, impl in ctx.items():
|
||||||
|
if target in OP_NAMES:
|
||||||
|
setattr(cls, target, impl)
|
||||||
|
|
||||||
|
|
||||||
DefaultHandler._init_cls()
|
DefaultHandler._init_cls()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue