diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh
old mode 100644
new mode 100755
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index dbd02188665..fec37fb6002 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -948,7 +948,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
             f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
         )
 
-    def output(self, x0: OpVarT) -> None:
+    def output(self, *args: OpVarT) -> None:
         raise AssertionError(
             f"{type(self).__name__}: ops.output should not appear at codegen time"
         )
diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py
index 5b45943b940..256079c8071 100644
--- a/torch/_inductor/dtype_propagation.py
+++ b/torch/_inductor/dtype_propagation.py
@@ -368,7 +368,7 @@ class DtypePropagationOpsHandler:
     ) -> None:
         return None
 
-    def output(self, x: DTypeArg) -> None:
+    def output(self, *args: DTypeArg) -> None:
         raise AssertionError(
             f"{type(self).__name__}: ops.output should not appear here"
         )
diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py
index afee8988253..4968544d80f 100644
--- a/torch/_inductor/loop_body.py
+++ b/torch/_inductor/loop_body.py
@@ -602,8 +602,8 @@ class LoopBodyBlock:
             return var
 
         @staticmethod
-        def output(result):
-            tracer.create_proxy("output", "output", (result,), {})
+        def output(*result):
+            tracer.create_proxy("output", "output", result, {})
 
         tracer = LightTracer()
         proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py
index 5338372f6af..0118d29368c 100644
--- a/torch/_inductor/ops_handler.py
+++ b/torch/_inductor/ops_handler.py
@@ -1,9 +1,11 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
+import inspect
 import itertools
 import re
 import warnings
+from io import StringIO
 from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
 from unittest.mock import patch
 
@@ -729,7 +731,7 @@ class OpsHandler(Generic[T]):
     ) -> T:
         raise NotImplementedError
 
-    def output(self, x0: T) -> None:
+    def output(self, *args: T) -> None:
         """This is a fake op used in analysis but not codegen"""
         raise NotImplementedError
 
@@ -755,7 +757,7 @@ class DefaultHandler(OpsHandler[Any]):
         provide generic op behavior.
 
         Args:
-            target: name of the op, see OpHandler.target
+            name: name of the op, see OpHandler.{name}
             args: positional args passed to the op
             kwargs: keyword args passed to the op
 
@@ -783,8 +785,41 @@ class DefaultHandler(OpsHandler[Any]):
 
     @classmethod
     def _init_cls(cls):
+        """
+        Here we codegen many functions of the form:
+
+            def add(self, a, b):
+                return self._default('add', (a, b), {})
+
+        and install them in cls. This is the same as _call_default above,
+        but is about 1.2x faster since CPython varargs parsing is slow.
+        """
+        code = StringIO()
         for target in OP_NAMES:
-            setattr(cls, target, cls._call_default(target))
+            sig = inspect.signature(getattr(OpsHandler, target))
+            if all(
+                p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+                and p.default is inspect.Parameter.empty
+                for p in sig.parameters.values()
+            ):
+                self_arg, *args = sig.parameters.keys()
+                assert self_arg == "self"
+                code.write(
+                    f"""
+                    def {target}(self, {', '.join(args)}):
+                        return self._default({target!r}, ({', '.join(args)}, ), {{}})
+                    """.strip()
+                )
+                code.write("\n\n")
+            else:
+                # slower fallback for ops with default or variadic arguments
+                setattr(cls, target, cls._call_default(target))
+
+        ctx: dict[str, Any] = {}
+        exec(code.getvalue(), ctx)
+        for target, impl in ctx.items():
+            if target in OP_NAMES:
+                setattr(cls, target, impl)
 
 
 DefaultHandler._init_cls()