From 46e83bb6377ad11c475fafc93c9ea15433056573 Mon Sep 17 00:00:00 2001 From: cyyever Date: Sat, 8 Feb 2025 07:19:37 +0000 Subject: [PATCH 01/28] Fix linter F821 error (#146665) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/146665 Approved by: https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan --- test/test_sort_and_select.py | 2 +- test/test_transformers.py | 4 ++-- torch/testing/_internal/jit_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 8c1f1b12f36..daa39964374 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -50,7 +50,7 @@ class TestSortAndSelect(TestCase): return ((b != b) | (a <= b)).all().item() else: - error( # noqa: F821 + raise ValueError( f'unknown order "{order}", must be "ascending" or "descending"' ) diff --git a/test/test_transformers.py b/test/test_transformers.py index eab1cb8a605..af711a6fb67 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -324,7 +324,7 @@ class TestTransformers(NNTestCase): encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8)) except AssertionError: continue - self.assertFalse(e, "Failed to catch unsupported uint8 type exception") # noqa: F821 + self.assertFalse(e, "Failed to catch unsupported uint8 type exception") test_train_bool = encoder(test, src_key_padding_mask=pad_mask) encoder.eval() @@ -335,7 +335,7 @@ class TestTransformers(NNTestCase): encoder(test, src_key_padding_mask=pad_mask.to(torch.int64)) except AssertionError as e: continue - self.assertFalse(e, "Failed to catch unsupported Long type exception") # noqa: F821 + self.assertFalse(e, "Failed to catch unsupported Long type exception") test_eval_bool = encoder(test, src_key_padding_mask=pad_mask) l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() diff --git a/torch/testing/_internal/jit_utils.py 
b/torch/testing/_internal/jit_utils.py index 06a1c2bd5d4..299eb999676 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -505,7 +505,7 @@ class JitTestCase(JitCommonTestCase): script_outputs = scripted_fn(*recording_inputs) with self.capture_stdout(): opt_script_outputs = scripted_fn(*recording_inputs) - with self.capture_stdout() as _python_stdout: + with self.capture_stdout(): python_outputs = python_fn(*inputs) if not IS_WINDOWS: self.assertExpected(script_stdout[0], subname='stdout') From 71498aeae3b22b7477331a0c8aef3a25f9da314f Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:53 -0800 Subject: [PATCH 02/28] [inductor] Refactor op handlers part 2 (#146252) This replaces the `__getattr__()` pattern used in (some) OpHandlers with a `DefaultHandler` class that has an implementation of every op that calls `self._default()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146252 Approved by: https://github.com/yanboliang --- torch/_inductor/codegen/common.py | 17 +- torch/_inductor/dtype_propagation.py | 10 ++ torch/_inductor/loop_body.py | 2 +- torch/_inductor/ops_handler.py | 232 ++++++++++++++++----------- torch/_inductor/select_algorithm.py | 2 +- torch/_inductor/subgraph_lowering.py | 2 +- torch/_inductor/virtualized.py | 6 +- 7 files changed, 169 insertions(+), 102 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index c254aacdb17..6e45c02b68e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions): f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" ) + def output(self, x0: OpVarT) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear at codegen time" + ) + + def placeholder(self, index: int) -> OpVarT: + raise AssertionError( + f"{type(self).__name__}: 
ops.placeholder should not appear at codegen time" + ) + @staticmethod def _unimplemented(name: str) -> Callable[..., OpVarT]: def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: @@ -2570,6 +2580,7 @@ class CSEProxy: ) -# Use mypy to check protocol implemented correctly -def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: - return h +if TYPE_CHECKING: + + class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]): + pass diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index efe0ebe2caf..5b45943b940 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -368,6 +368,16 @@ class DtypePropagationOpsHandler: ) -> None: return None + def output(self, x: DTypeArg) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear here" + ) + + def placeholder(self, index: int) -> torch.dtype: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear here" + ) + if TYPE_CHECKING: diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 21f63a11b67..aab9d318762 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -449,7 +449,7 @@ class LoopBodyBlock: ) class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] - self.name = "CaptureIndexing" + name = "CaptureIndexing" def load(self, name: str, index: sympy.Expr): index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 935c5f6fc36..22ce7154c6b 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -3,7 +3,17 @@ from __future__ import annotations import itertools import re -from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union +import warnings +from typing import ( + Any, + Callable, + Literal, + NamedTuple, + Optional, + TYPE_CHECKING, + TypeVar, 
+ Union, +) from typing_extensions import Protocol from unittest.mock import patch @@ -752,6 +762,14 @@ class OpsHandler(Protocol[T]): ) -> T: ... + def output(self, x0: T) -> None: + """This is a fake op used in analysis but not codegen""" + ... + + def placeholder(self, index: int) -> T: + """This is a fake op used in analysis but not codegen""" + ... + _ignore_op_re = re.compile(r"_.*|paren").fullmatch @@ -763,15 +781,53 @@ def list_ops(cls: type[Any]): OP_NAMES = list_ops(OpsHandler) -def _return_none(*args, **kwargs): - return None +class DefaultHandler: + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + """ + Default implementation for all ops. Override in a subclass to + provide generic op behavior. + + Args: + target: name of the op, see OpHandler.target + args: positional args passed to the op + kwargs: keyword args passed to the op + + Returns: + return value of the op + + """ + raise NotImplementedError + + def __getattr__(self, name: str) -> Any: + def fallback(*args: Any, **kwargs: Any) -> Any: + return self._default(name, args, kwargs) + + # would like to remove this function entirely, but it's used in MTIA backend + warnings.warn(f"undefined OpHandler.{name}, please add missing op schema") + return fallback + + @staticmethod + def _call_default(target: str): + def call_default(self, *args, **kwargs): + return self._default(target, args, kwargs) + + call_default.__name__ = target + return call_default + + @classmethod + def _init_cls(cls): + for target in OP_NAMES: + setattr(cls, target, cls._call_default(target)) -class NoopHandler: +DefaultHandler._init_cls() + + +class NoopHandler(DefaultHandler): name = "NoopHandler" - def __getattr__(self, name: str) -> Callable[..., None]: - return _return_none + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return None @staticmethod def masked(mask, body, other) -> None: @@ -794,9 +850,10 @@ class NoopHandler: return sympy.S.Zero -# 
Use mypy to check protocol implemented correctly -def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]: - return h +if TYPE_CHECKING: + + class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]): + pass # mypy will error if we got any of the signatures wrong class BasicMathOps: @@ -878,16 +935,14 @@ class BasicMathOps: return f"-{a}" -class MockHandler(BasicMathOps): +class MockHandler(BasicMathOps, DefaultHandler): name = "MockHandler" - def __getattr__(self, name): - def inner(*args, **kwargs): - fargs = [_arg_str(a) for a in args] - fargs.extend(f"{k}={v}" for k, v in kwargs.items()) - return f"ops.{name}({', '.join(fargs)})" - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + fargs = [*map(_arg_str, args)] + for k, v in kwargs.items(): + fargs.append(f"{k}={_arg_str(v)}") + return f"ops.{name}({', '.join(fargs)})" @staticmethod def masked(mask, body, other) -> str: @@ -916,15 +971,16 @@ class MockHandler(BasicMathOps): return sympy_index_symbol(str(index_var)) -# Use mypy to check protocol implemented correctly -def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: - return h +if TYPE_CHECKING: + + class _typecheck_MockHandler(MockHandler, OpsHandler[str]): + pass # mypy will error if we got any of the signatures wrong -class KernelFormatterHandler: +class KernelFormatterHandler(DefaultHandler): def __init__(self, parent_handler): self.parent_handler = parent_handler - self.output = IndentedBuffer(1) + self._output = IndentedBuffer(1) self.var_counter = itertools.count() @staticmethod @@ -936,8 +992,8 @@ class KernelFormatterHandler: names = ["index", "rindex"] if rindex is not None else ["index"] formatter = KernelFormatterHandler(MockHandler()) - with formatter.output.indent(-1): - formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + with formatter._output.indent(-1): + formatter._output.writeline(f"def inner_fn({', '.join(names)}):") for name, arg in zip(names, args): 
if arg: lhs = ", ".join( @@ -946,7 +1002,7 @@ class KernelFormatterHandler: for v in arg ] ) - formatter.output.writeline(f"{lhs} = {name}") + formatter._output.writeline(f"{lhs} = {name}") with V.set_ops_handler(formatter), patch.object( FlexibleLayout, "allow_indexing", True @@ -954,21 +1010,19 @@ class KernelFormatterHandler: result = ir_fn(*args) return formatter.getvalue(result) - def __getattr__(self, name) -> Callable[..., Any]: - def inner(*args, **kwargs): - line = getattr(self.parent_handler, name)(*args, **kwargs) - if name == "indirect_indexing": - return line + def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol: + return self.parent_handler.indirect_indexing(*args, **kwargs) - def write(line): - # replace line with a new variable name - varname = f"tmp{next(self.var_counter)}" - self.output.writeline(f"{varname} = {line}") - return varname + def _write(self, line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self._output.writeline(f"{varname} = {line}") + return varname - return pytree.tree_map(write, line) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return pytree.tree_map( + self._write, getattr(self.parent_handler, name)(*args, **kwargs) + ) def reduction( self, @@ -980,44 +1034,34 @@ class KernelFormatterHandler: line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) num_values = reduction_num_outputs(reduction_type) varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] - self.output.writeline(f"{','.join(varnames)} = {line}") + self._output.writeline(f"{','.join(varnames)} = {line}") return tuple(varnames) if num_values > 1 else varnames[0] def getvalue(self, result): - self.output.writeline(f"return {result}") - return self.output.getvalue() + self._output.writeline(f"return {result}") + return self._output.getvalue() -# Use mypy to check protocol implemented correctly -def 
_typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h +if TYPE_CHECKING: + + class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]): + pass # mypy will error if we got any of the signatures wrong -class WrapperHandler(Generic[T]): +class WrapperHandler(DefaultHandler): def __init__(self, inner: Any): self._inner = inner - def __getattr__(self, item): - return getattr(self._inner, item) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return getattr(self._inner, name)(*args, **kwargs) -# Use mypy to check protocol implemented correctly -def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]: - return h - - -class AddParenHandler(WrapperHandler[T]): - def __getattr__(self, name): - def inner(*args, **kwargs): - val = getattr(self._inner, name)(*args, **kwargs) - return f"({val})" - - return inner - - -# Use mypy to check protocol implemented correctly -def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: - return h +class AddParenHandler(WrapperHandler): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + val = getattr(self._inner, name)(*args, **kwargs) + if not val or isinstance(val, (sympy.Expr, tuple, list)): + return val + return f"({val})" class OpCountResult(NamedTuple): @@ -1027,7 +1071,7 @@ class OpCountResult(NamedTuple): nontrivial_read_count: int -class OpCounterCSE: +class OpCounterCSE(DefaultHandler): """Shim to count how many ops are used""" def __init__(self, inner): @@ -1035,18 +1079,15 @@ class OpCounterCSE: self.parent_handler = inner self.op_count = 0 self.var_names = {} - self._used_ops = OrderedSet[str]() + self._used_ops: OrderedSet[str] = OrderedSet() self._read_names: list[str] = [] self._nontrivial_read_count = 0 - def __getattr__(self, name): - def inner(*args, **kwargs): - return pytree.tree_map( - self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) 
- ) - + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: self._used_ops.add(name) - return inner + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) def _update_count(self, val): varname = self.var_names.get(val) @@ -1111,8 +1152,10 @@ class OpCounterCSE: ) -def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: - return h +if TYPE_CHECKING: + + class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]): + pass # mypy will error if we got any of the signatures wrong class ExtractConstantsHandler(NoopHandler): @@ -1125,44 +1168,45 @@ class ExtractConstantsHandler(NoopHandler): return ir.Constant(value=value, dtype=dtype, device=self.device) -def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]: - return h +if TYPE_CHECKING: + + class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]): + pass # mypy will error if we got any of the signatures wrong -class SimpleCSEHandler(WrapperHandler[T]): +class SimpleCSEHandler(WrapperHandler): """Wraps the underlying handler with a CSE pass NOTE: Compared to codegen level CSE this is simplified as it doesn't support stores which require load cache invalidation. 
""" - def __init__(self, inner: OpsHandler[T]): + def __init__(self, inner: Any): super().__init__(inner) - self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {} + self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {} self.mock = MockHandler() def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: return super().indirect_indexing(*args, **kwargs) # type: ignore[misc] - def store(self, *args, **kwargs) -> T: + def store(self, *args, **kwargs) -> None: raise NotImplementedError("store not implemented") - def store_reduction(self, *args, **kwargs) -> T: + def store_reduction(self, *args, **kwargs) -> None: raise NotImplementedError("store not implemented") - def __getattr__(self, name) -> Callable[..., Any]: - def inner(*args, **kwargs): - key = getattr(self.mock, name)(*args, **kwargs) - val = self.cse_cache.get(key) - if val is not None: - return val - - val = getattr(self._inner, name)(*args, **kwargs) - self.cse_cache[key] = val + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + key = getattr(self.mock, name)(*args, **kwargs) + val = self.cse_cache.get(key) + if val is not None: return val - return inner + val = getattr(self._inner, name)(*args, **kwargs) + self.cse_cache[key] = val + return val -def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]: - return h +if TYPE_CHECKING: + + class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]): + pass # mypy will error if we got any of the signatures wrong diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bd96e830bcc..d016af99954 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel): template_mask = self.template_mask class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = name + name = "StoreOutputSubstitution" def store( self, diff --git 
a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index ce35959c532..0166534e5fb 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -135,7 +135,7 @@ class InputDescriptor: device: torch.device -class TracingOpsHandler(WrapperHandler[T]): +class TracingOpsHandler(WrapperHandler): def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: parent = tracer.create_proxy("placeholder", "ops", (), {}) super().__init__(parent) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 4de41846166..55ff6ce32b3 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -59,7 +59,7 @@ from __future__ import annotations from contextlib import AbstractContextManager, contextmanager from threading import local -from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union from torch.utils._ordered_set import OrderedSet @@ -154,7 +154,9 @@ class NullKernelHandler(NullHandler): self.index_dtype = "tl.int64" -_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler) +_ops: Virtualized[OpsHandler[Any]] = Virtualized( + "ops", cast(type[OpsHandler[Any]], MockHandler) +) _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) From 0e31e5932b5558fb80f7ca77a90ea2f9b4a14d45 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:53 -0800 Subject: [PATCH 03/28] [inductor] Refactor op handlers part 3 (#146254) Fixes type errors that arise from typing `V.ops`. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146254 Approved by: https://github.com/shunting314 ghstack dependencies: #146252 --- torch/_inductor/ir.py | 60 +++++++++++++++++++--------------- torch/_inductor/lowering.py | 8 +++-- torch/_inductor/virtualized.py | 4 ++- 3 files changed, 42 insertions(+), 30 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d90398ca043..6a800d9e81d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -75,7 +75,7 @@ from .dependencies import ( var_builder, ) from .loop_body import LoopBody -from .ops_handler import OpCounterCSE, OpCountResult +from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode from .runtime.benchmarking import benchmarker from .runtime.hints import DeviceProperties, ReductionHint from .utils import ( @@ -916,9 +916,9 @@ class Pointwise(Loops): output_name: Optional[str], indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], - ) -> OpsValue: + ) -> None: loader = self.make_loader() - return ops.store(output_name, indexer(vars), loader(vars)) + return ops.store(output_name or "unnamed", indexer(vars), loader(vars)) def constant_to_device(self, device: torch.device) -> IRNode: """Move this to a given device. Requires that all reads are to constants.""" @@ -932,7 +932,7 @@ class Pointwise(Loops): @ir_dataclass class Scatter(Pointwise): output_indexer: Callable[[Sequence[Expr]], Expr] - scatter_mode: Optional[str] = None + scatter_mode: StoreMode = None def constant_to_device(self, device: torch.device) -> IRNode: """Move this to a given device. 
Requires that all reads are to constants.""" @@ -952,8 +952,10 @@ class Scatter(Pointwise): output_name: Optional[str], indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], - ) -> OpsValue: + ) -> None: loader = self.make_loader() + if output_name is None: + output_name = "unnamed" return ops.store( output_name, indexer(self.output_indexer(vars)), @@ -1038,7 +1040,7 @@ def get_reduction_combine_fn( @ir_dataclass class Reduction(Loops): reduction_ranges: Sequence[_IntLike] - reduction_type: str + reduction_type: ReductionType # self.dtype represents the dst dtype src_dtype: torch.dtype reduction_hint: ReductionHint @@ -1065,14 +1067,14 @@ class Reduction(Loops): indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], reduction_vars: Sequence[Symbol], - ) -> OpsValue: + ) -> None: value = ops.reduction( self.dtype, self.src_dtype, self.reduction_type, self.inner_fn(vars, reduction_vars), ) - return ops.store_reduction(output_name, indexer(vars), value) + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) def index_length(self) -> int: return len(self.ranges) + len(self.reduction_ranges) @@ -1110,7 +1112,7 @@ class Reduction(Loops): inner_fn: Callable[..., OpsValue], ranges: Sequence[_IntLike], reduction_ranges: Sequence[_IntLike], - reduction_type: str, + reduction_type: Union[ReductionType, Literal["scan"]], reduction_numel: Expr, input_node: Optional[IRNode] = None, ) -> tuple[ReductionHint, _IntLike]: @@ -1196,7 +1198,7 @@ class Reduction(Loops): inner_fn=inner_fn, ranges=ranges, reduction_ranges=reduction_ranges, - reduction_type=reduction_type, + reduction_type=reduction_type if reduction_type != "scan" else "sum", src_dtype=src_dtype, reduction_hint=ReductionHint.DEFAULT, ) @@ -1323,7 +1325,7 @@ class Reduction(Loops): inner_fn: Callable[..., Any], ranges: Sequence[Expr], reduction_ranges: Sequence[Expr], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, 
input_node: Optional[IRNode] = None, ) -> TensorBox: @@ -1593,7 +1595,7 @@ class Reduction(Loops): original_reduction_ranges: Sequence[Expr], new_ranges: list[Expr], new_reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, ) -> TensorBox: @@ -1655,7 +1657,7 @@ class Reduction(Loops): inner_fn: Callable[..., Any], ranges: Sequence[Expr], reduction_ranges: Sequence[Expr], - reduction_type: str, + reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, ) -> TensorBox: @@ -1696,7 +1698,7 @@ class Reduction(Loops): original_reduction_ranges: Sequence[Expr], new_ranges: list[Integer], new_reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint, ) -> TensorBox: """ @@ -1735,7 +1737,7 @@ class WelfordReduction(Reduction): inner_fns: Sequence[Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]], ranges: Sequence[Integer], reduction_ranges: Sequence[Integer], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint, output_index: int, ) -> None: @@ -1767,7 +1769,7 @@ class WelfordReduction(Reduction): indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], reduction_vars: Sequence[Symbol], - ) -> OpsValue: + ) -> None: values = ops.reduction( self.dtype, self.src_dtype, @@ -1775,7 +1777,7 @@ class WelfordReduction(Reduction): self.inner_fn(vars, reduction_vars), ) value = values[self.output_index] - return ops.store_reduction(output_name, indexer(vars), value) + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) @classmethod def create( # type: ignore[override] @@ -1785,7 +1787,7 @@ class WelfordReduction(Reduction): inner_fns: Sequence[Callable[..., Any]], ranges: list[Integer], reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, ) -> 
Sequence[TensorBox]: assert reduction_type in ("welford_reduce", "welford_combine") @@ -1911,7 +1913,7 @@ class WelfordReduction(Reduction): inner_fns: Sequence[Callable[..., Any]], ranges: list[Integer], reduction_ranges: list[Integer], - reduction_type: str, + reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, ) -> Sequence[TensorBox]: @@ -2031,11 +2033,13 @@ class Scan(Loops): indexer: Callable[[Sequence[_IntLike]], Never], vars: Sequence[Expr], scan_vars: Sequence[Symbol], - ) -> OpsValue: + ) -> None: idx = self.reindex(vars, scan_vars) - values = [inner_fn(idx) for inner_fn in self.inner_fns] + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.scan(self.dtypes, self.combine_fn, values) - return ops.store(output_name, indexer(idx), result[self.output_index]) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) def get_reduction_type(self) -> Optional[str]: # return self.scan_op @@ -2229,11 +2233,13 @@ class Sort(Loops): indexer: Callable[[Sequence[Expr]], Expr], vars: Sequence[Expr], reduction_vars: Sequence[Expr], - ) -> OpsValue: + ) -> None: idx = self.reindex(vars, reduction_vars) - values = [inner_fn(idx) for inner_fn in self.inner_fns] + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.sort(self.dtypes, values, self.stable, self.descending) - return ops.store(output_name, indexer(idx), result[self.output_index]) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) def get_reduction_type(self) -> Optional[str]: return "sort" @@ -3790,7 +3796,7 @@ class Buffer(IRNode): def loader(index): # type: ignore[no-untyped-def] indexer = self.make_indexer() - return ops.load(self.name, indexer(index)) + return ops.load(self.name or "unnamed", indexer(index)) return loader @@ -3983,7 +3989,7 @@ class ComputedBuffer(OperationBuffer): return self.data.make_loader() return super().make_loader() - def 
get_store_function(self) -> Callable[..., OpsValue]: + def get_store_function(self) -> Callable[..., None]: indexer = self.get_layout().as_fixed().make_indexer() if isinstance(self.data, (Reduction, Scan, Sort)): return partial(self.data.store_reduction, self.name, indexer) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index eacf4dbb3d0..719a49312f4 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -12,7 +12,7 @@ import os import warnings from collections import defaultdict from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec from unittest.mock import patch @@ -81,6 +81,10 @@ from .utils import ( from .virtualized import ops, V +if TYPE_CHECKING: + from .ops_handler import ReductionType + + _T = TypeVar("_T") _P = ParamSpec("_P") @@ -5633,7 +5637,7 @@ def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): ) -def make_reduction(reduction_type: str, override_return_dtype=None): +def make_reduction(reduction_type: ReductionType, override_return_dtype=None): def inner(x, axis=None, keepdims=False, *, dtype=None): kwargs = _make_reduction_inner( x, diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 55ff6ce32b3..d3c9725f59a 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -308,7 +308,9 @@ class OpsWrapper: return _ops.indirect_indexing(index, size, check, wrap_neg) -ops = OpsWrapper() +# we lie about the type of ops so the rest of the codebase typecheck properly +# DefaultHandler implements the OpsHandler protocol via metaprogramming +ops = cast(OpsHandler[Any], OpsWrapper()) class _V: From 403db2faee8cb93a15e9b3261952005c54f28010 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:54 -0800 Subject: [PATCH 04/28] [inductor] Refactor op handlers 
part 4 (#146255) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This replaces the `__getattr__()` pattern used in remaining OpHandlers with a `DefaultHandler` class defined in part 2. Some compile time wins from this as well: ``` 2025-02-02T19:46:32.2033010Z 2025-02-02T19:46:32.2036607Z WIN: benchmark ('add_loop_inductor', 'compile_time_instruction_count') failed, actual result 29633182927 is -1.71% lower than expected 30150000000 ±1.50% please update the expected results. 2025-02-02T19:46:32.2037575Z 2025-02-02T19:46:32.2037907Z please update all results that changed significantly, and not only the failed ones 2025-02-02T19:46:32.2039291Z PASS: benchmark ('add_loop_inductor_dynamic_gpu', 'compile_time_instruction_count') pass, actual result 43986879172 -1.02% is within expected 44440000000 ±2.50% 2025-02-02T19:46:32.2040131Z 2025-02-02T19:46:32.2041180Z WIN: benchmark ('add_loop_inductor_gpu', 'compile_time_instruction_count') failed, actual result 26246225695 is -1.85% lower than expected 26740000000 ±1.50% please update the expected results. 
2025-02-02T19:46:32.2042188Z ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/146255 Approved by: https://github.com/shunting314 ghstack dependencies: #146252, #146254 --- .../pr_time_benchmarks/expected_results.csv | 10 +- .../_inductor/analyze_preserves_zero_mask.py | 62 ++++++----- torch/_inductor/codegen/common.py | 105 +++++++++--------- torch/_inductor/codegen/triton.py | 32 +++--- torch/_inductor/dependencies.py | 36 +++--- torch/_inductor/index_propagation.py | 45 ++++---- torch/_inductor/loop_body.py | 7 +- torch/_inductor/virtualized.py | 14 +-- 8 files changed, 164 insertions(+), 147 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 388b8d1a5f6..c80c46dc1e7 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025 -add_loop_inductor,compile_time_instruction_count,30150000000,0.015 +add_loop_inductor,compile_time_instruction_count,29630000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015 @@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000, -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015 @@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000, -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015 
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015 diff --git a/torch/_inductor/analyze_preserves_zero_mask.py b/torch/_inductor/analyze_preserves_zero_mask.py index a03439c2bae..abdf1320bc2 100644 --- a/torch/_inductor/analyze_preserves_zero_mask.py +++ b/torch/_inductor/analyze_preserves_zero_mask.py @@ -1,6 +1,6 @@ import dataclasses import itertools -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import sympy @@ -9,6 +9,7 @@ from torch._inductor import config from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from torch._inductor.index_propagation import SymPyOps, TypedExpr +from .ops_handler import DefaultHandler, OpsHandler from .virtualized import StoreMode, V @@ -20,7 +21,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol: return sympy.Symbol(f"unknown_{count}") -class PreservesZeros(SymPyOps): +class PreservesZeros(SymPyOps, DefaultHandler): """ For prologue kernels where the loads are masked, does the final store of this kernel preserve the zeros. 
@@ -54,18 +55,20 @@ class PreservesZeros(SymPyOps): self = V.get_ops_handler() return construct_symbol(next(self.count), torch.int32) - def __getattr__(self, name: str) -> Callable[..., Any]: + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: from torch._inductor.codegen.common import OpDecompositions - def inner(*args: Any, **kwargs: Any) -> TypedExpr: - if hasattr(OpDecompositions, name): - return getattr(OpDecompositions, name)(*args, **kwargs).value + if hasattr(OpDecompositions, name): + return getattr(OpDecompositions, name)(*args, **kwargs).value - nonlocal self - dtype = getattr(self.dtype_prop, name)(*args, **kwargs) - return TypedExpr(construct_symbol(next(self.count), dtype), dtype) + dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + return TypedExpr(construct_symbol(next(self.count), dtype), dtype) - return inner + +if TYPE_CHECKING: + + class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]): + pass def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool: @@ -88,7 +91,7 @@ class DTypeContainer: is_scalar: bool = False -class RecordLowPrecisionOps: +class RecordLowPrecisionOps(DefaultHandler): def __init__(self) -> None: self.low_precision_numeric_op = False self.dtype_prop = DtypePropagationOpsHandler() @@ -111,28 +114,31 @@ class RecordLowPrecisionOps: def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: return sympy.S.Zero - def __getattr__(self, name: str) -> Callable[..., Any]: - def low_prec_float(dtype: torch.dtype) -> bool: - return dtype.is_floating_point and dtype.itemsize < 4 + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) + if name == "constant": + out = DTypeContainer(torch.float, is_scalar=True) - def inner(*args: Any, **kwargs: Any) -> DTypeContainer: - out_dtype = getattr(self.dtype_prop, 
name)(*args, **kwargs) - out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) - if name == "constant": - out = DTypeContainer(torch.float, is_scalar=True) + uses_low_prec = any( + isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype) + for dtype_cont in itertools.chain((out,), args, kwargs.values()) + ) - uses_low_prec = any( - isinstance(dtype_cont, DTypeContainer) - and low_prec_float(dtype_cont.dtype) - for dtype_cont in itertools.chain((out,), args, kwargs.values()) - ) + if uses_low_prec and name not in self.non_numeric_ops: + self.low_precision_numeric_op = True - if uses_low_prec and name not in self.non_numeric_ops: - self.low_precision_numeric_op = True + return out - return out - return inner +if TYPE_CHECKING: + + class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]): + pass + + +def low_prec_float(dtype: torch.dtype) -> bool: + return dtype.is_floating_point and dtype.itemsize < 4 def can_codegen_without_upcasts( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 6e45c02b68e..ab5280e1d5b 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -41,7 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val from .. 
import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler -from ..ops_handler import BasicMathOps +from ..ops_handler import BasicMathOps, DefaultHandler from ..utils import ( boolean_ops, DeferredLineBase, @@ -2263,7 +2263,7 @@ class KernelTemplate: raise NotImplementedError -class CSEProxy: +class CSEProxy(DefaultHandler): name = "CSEProxy" def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): @@ -2272,69 +2272,66 @@ class CSEProxy: self.kernel = kernel self.parent_handler = parent_handler - def __getattr__(self, name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] - def inner(*args: Any, **kwargs: Any) -> CSEVariable: - bounds = self._bound_variable(name, *args, **kwargs) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + bounds = self._bound_variable(name, *args, **kwargs) - value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type] - dtype_handler = DtypePropagationOpsHandler() + value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + dtype_handler = DtypePropagationOpsHandler() - output_idx = 0 + output_idx = 0 - def do_cse(v: str) -> CSEVariable: - # cpp backend doesnt set current device - TODO: fix - if V.graph.current_device is not None: - device_str = V.graph.get_current_device_or_throw().type - triton_backend = ( - config.cpu_backend == "triton" - if device_str == "cpu" - else config.cuda_backend == "triton" - if device_str != "mps" - else False - ) - else: - triton_backend = False - - # only triton backend tracks dtype currently - if triton_backend: - if name == "masked": - output_dtype = value.dtype - else: - output_dtype = getattr( - dtype_handler, - name, - )(*args, **kwargs) - else: - # cpp backend doesnt track dtype yet - output_dtype = None - - csevar = V.kernel.cse.generate( - V.kernel.compute, - v, - bounds=bounds, - dtype=output_dtype, + def do_cse(v: str) -> CSEVariable: + # cpp backend doesnt set 
current device - TODO: fix + if V.graph.current_device is not None: + device_str = V.graph.get_current_device_or_throw().type + triton_backend = ( + config.cpu_backend == "triton" + if device_str == "cpu" + else config.cuda_backend == "triton" + if device_str != "mps" + else False ) + else: + triton_backend = False - nonlocal output_idx - if config.test_configs.runtime_triton_dtype_assert and triton_backend: - from torch._inductor.codegen.triton import triton_type + # only triton backend tracks dtype currently + if triton_backend: + if name == "masked": + output_dtype = value.dtype + else: + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + else: + # cpp backend doesnt track dtype yet + output_dtype = None - # we tree_map over the output, so we need to fetch corresponding dtype - if isinstance(output_dtype, (list, tuple)): - output_dtype = output_dtype[output_idx] + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) - V.kernel.compute.writeline( - f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" - ) - output_idx += 1 + nonlocal output_idx + if config.test_configs.runtime_triton_dtype_assert and triton_backend: + from torch._inductor.codegen.triton import triton_type - csevar.update_on_args(name, args, kwargs) + # we tree_map over the output, so we need to fetch corresponding dtype + if isinstance(output_dtype, (list, tuple)): + output_dtype = output_dtype[output_idx] - return csevar + V.kernel.compute.writeline( + f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" + ) + output_idx += 1 - return pytree.tree_map(do_cse, value) + csevar.update_on_args(name, args, kwargs) - return inner + return csevar + + return pytree.tree_map(do_cse, value) def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]: """ diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index c0898d13c26..7b123f590fa 100644 --- 
a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -31,6 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty from ...utils._sympy.value_ranges import ValueRanges from .. import config, ir, metrics from ..codecache import code_hash, get_path, PyCodeCache +from ..ops_handler import DefaultHandler from ..runtime.benchmarking import benchmarker from ..runtime.hints import ( AutotuneHint, @@ -2872,24 +2873,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): dtype_handler = DtypePropagationOpsHandler() - class CSEProxy: - def __getattr__(self, name: str) -> Callable[..., CSEVariable]: - def inner(*args, **kwargs): - nonlocal helper_name - helper_name += f"_{name}" + class CSEProxy(DefaultHandler): + def _default( + self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + nonlocal helper_name + helper_name += f"_{name}" - output_dtype = getattr( - dtype_handler, - name, - )(*args, **kwargs) + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) - return cse.generate( - helper, - getattr(overrides, name)(*args, **kwargs), - dtype=output_dtype, - ) - - return inner + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + dtype=output_dtype, + ) with helper.indent(), V.set_ops_handler(CSEProxy()): outputs = fn(*args) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index f66e2e791a1..820e737414d 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -4,7 +4,17 @@ import itertools import logging import re from collections.abc import Sequence -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) from unittest.mock import patch import sympy @@ -15,6 +25,7 @@ from torch.utils._ordered_set import OrderedSet from ..utils._sympy.symbol 
import make_symbol, SymT from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler from .utils import ( get_dtype_size, reduction_num_outputs, @@ -737,19 +748,16 @@ def canonicalization_prefix() -> str: # ops handler which computes all the free unbacked symbols for an IR -class FreeUnbackedSymbolsOpsHandler: +class FreeUnbackedSymbolsOpsHandler(DefaultHandler): symbols: OrderedSet[sympy.Symbol] def __init__(self) -> None: self.symbols = OrderedSet() - def __getattr__(self, name: str) -> Callable[..., Any]: - def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None: - for a in itertools.chain(args, kwargs.values()): - if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): - self.symbols |= free_unbacked_symbols(a) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= free_unbacked_symbols(a) def indirect_indexing( self, @@ -791,10 +799,12 @@ class FreeUnbackedSymbolsOpsHandler: body() -def _typecheck_FreeUnbackedSymbolsOpsHandler( - h: FreeUnbackedSymbolsOpsHandler, -) -> OpsHandler[None]: - return h +if TYPE_CHECKING: + + class _typecheck_FreeUnbackedSymbolsOpsHandler( + FreeUnbackedSymbolsOpsHandler, OpsHandler[None] + ): + pass def extract_free_unbacked_symbols( diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 310df89ffa3..741864e41e1 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -21,8 +21,9 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing. 
""" import itertools +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Callable, Literal, Optional, overload, Union +from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union from typing_extensions import TypeAlias import sympy @@ -32,6 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges +from .ops_handler import DefaultHandler, OpsHandler from .sizevars import evaluate_expr from .utils import generate_assert from .virtualized import V @@ -185,7 +187,7 @@ class IndexPropVar: IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] -class IndexPropagation: +class IndexPropagation(DefaultHandler): """Ops wrapper that tries to propagate constant and index_expr values through the computation. This aims to maximize the compile time simplification possible, and convert @@ -247,19 +249,19 @@ class IndexPropagation: def fallback( self, name: Literal["indirect_indexing"], - args: tuple[Any, ...], + args: Sequence[Any], kwargs: dict[str, Any], ) -> IndexPropVar: ... @overload def fallback( - self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] ) -> IndexPropResult: ... 
def fallback( - self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] ) -> IndexPropResult: # Fallback to the wrapped handler new_args = [self.unwrap(a) for a in args] @@ -267,7 +269,7 @@ class IndexPropagation: return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) def propagate_sympy( - self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] ) -> IndexPropResult: # Build a new SymPy expression from this ops call def unwrap(a: Union[Any, IndexPropVar]) -> Any: @@ -288,22 +290,19 @@ class IndexPropagation: return self.fallback(name, args, kwargs) return IndexPropVar.new_symbolic(new_expr) - def __getattr__(self, name: str) -> Callable[..., IndexPropResult]: - def inner(*args: Any, **kwargs: Any) -> IndexPropResult: - if not hasattr(SymPyOps, name): - return self.fallback(name, args, kwargs) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) - var_arguments = [ - a - for a in itertools.chain(args, kwargs.values()) - if isinstance(a, IndexPropVar) - ] - if not all(v.is_symbolic for v in var_arguments): - return self.fallback(name, args, kwargs) + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) - return self.propagate_sympy(name, args, kwargs) - - return inner + return self.propagate_sympy(name, args, kwargs) def statically_true(self, e): """ @@ -371,3 +370,9 @@ class IndexPropagation: "indirect_indexing", (index, size, check, wrap_neg), {} ).value return indirect_var + + +if TYPE_CHECKING: + + class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]): + pass diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 
aab9d318762..feb88a09e75 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -17,6 +17,7 @@ from torch.utils._sympy.symbol import SymT from . import config, dependencies from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs from .virtualized import ops, V @@ -653,11 +654,11 @@ class LoopBodyBlock: return copy -class CountOps: +class CountOps(DefaultHandler): def __init__(self, inner: Any, counts: collections.Counter[str]): self._inner = inner self._counts = counts - def __getattr__(self, name): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: self._counts[name] += 1 - return getattr(self._inner, name) + return getattr(self._inner, name)(*args, **kwargs) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index d3c9725f59a..f82c84afdad 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -64,6 +64,7 @@ from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union from torch.utils._ordered_set import OrderedSet from .ops_handler import ( # noqa: F401 + DefaultHandler, KernelFormatterHandler, MockHandler, OpsHandler, @@ -274,18 +275,15 @@ class OpsValue: return ops.bitwise_left_shift(self, n) -class OpsWrapper: +class OpsWrapper(DefaultHandler): """This wraps any returned IR values into an `OpsValue` instance, so that we can overload the magic methods for writing mathematical expressions fluently. 
""" - def __getattr__(self, name): - def inner(*args, **kwargs): - new_args = [OpsWrapper._unwrap(a) for a in args] - new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} - return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) @staticmethod def _unwrap(x): From 06604c4ec1e0ead7b939ecb8ca569f0ccbd00c64 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:54 -0800 Subject: [PATCH 05/28] [inductor] Refactor op handlers part 5 (#146257) This makes OpHandler just a normal class using inheritance, and removes typing workarounds needed because it wasn't Pull Request resolved: https://github.com/pytorch/pytorch/pull/146257 Approved by: https://github.com/shunting314 ghstack dependencies: #146252, #146254, #146255 --- test/inductor/test_op_completeness.py | 20 +- test/test_sympy_utils.py | 3 +- .../_inductor/analyze_preserves_zero_mask.py | 32 +- torch/_inductor/bounds.py | 122 +++++- torch/_inductor/codegen/common.py | 21 +- torch/_inductor/codegen/halide.py | 8 +- torch/_inductor/codegen/mps.py | 8 +- torch/_inductor/codegen/triton.py | 8 +- torch/_inductor/dependencies.py | 22 +- torch/_inductor/index_propagation.py | 10 +- torch/_inductor/loop_body.py | 4 +- torch/_inductor/ops_handler.py | 403 ++++++++---------- torch/_inductor/output_code.py | 8 - torch/_inductor/subgraph_lowering.py | 4 +- torch/_inductor/virtualized.py | 10 +- torch/utils/_sympy/value_ranges.py | 104 +---- 16 files changed, 332 insertions(+), 455 deletions(-) diff --git a/test/inductor/test_op_completeness.py b/test/inductor/test_op_completeness.py index 04fac4870fd..23d59a78941 100644 --- a/test/inductor/test_op_completeness.py +++ 
b/test/inductor/test_op_completeness.py @@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides from torch._inductor.codegen.halide import HalideOverrides from torch._inductor.codegen.mps import MetalOverrides from torch._inductor.codegen.triton import TritonKernelOverrides -from torch._inductor.ops_handler import list_ops, OP_NAMES +from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler from torch._inductor.test_case import TestCase class TestOpCompleteness(TestCase): def verify_ops_handler_completeness(self, handler): - op_names = list_ops(handler) - if OP_NAMES == op_names: - return - print(f"Missing ops: {OP_NAMES - op_names}") - print(f"Extra ops: {op_names - OP_NAMES}") - self.assertEqual(", ".join(OP_NAMES - op_names), "") - self.assertEqual(", ".join(op_names - OP_NAMES), "") + for op in OP_NAMES: + self.assertIsNot( + getattr(handler, op), + getattr(OpsHandler, op), + msg=f"{handler} must implement {op}", + ) + extra_ops = list_ops(handler) - OP_NAMES + if extra_ops: + raise AssertionError( + f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`" + ) def test_triton_overrides(self): self.verify_ops_handler_completeness(TritonKernelOverrides) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 5cd93027417..dddb73c2851 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -34,7 +34,8 @@ from torch.utils._sympy.reference import ( ) from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve -from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import ValueRanges +from torch._inductor.bounds import ValueRangeAnalysis UNARY_OPS = [ diff --git a/torch/_inductor/analyze_preserves_zero_mask.py b/torch/_inductor/analyze_preserves_zero_mask.py index abdf1320bc2..974960b9589 100644 --- 
a/torch/_inductor/analyze_preserves_zero_mask.py +++ b/torch/_inductor/analyze_preserves_zero_mask.py @@ -9,7 +9,7 @@ from torch._inductor import config from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from torch._inductor.index_propagation import SymPyOps, TypedExpr -from .ops_handler import DefaultHandler, OpsHandler +from .ops_handler import DefaultHandler from .virtualized import StoreMode, V @@ -32,27 +32,22 @@ class PreservesZeros(SymPyOps, DefaultHandler): self.store_preserves_zeros: Optional[bool] = None self.dtype_prop = DtypePropagationOpsHandler() - @staticmethod - def load(name: str, index: sympy.Expr) -> TypedExpr: + def load(self, name: str, index: sympy.Expr) -> TypedExpr: # In prologue fusion, all loads get broadcasted - dtype = V.get_ops_handler().dtype_prop.load(name, index) + dtype = self.dtype_prop.load(name, index) return TypedExpr( sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype ) - @staticmethod def store( - name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None ) -> None: - self = V.get_ops_handler() assert isinstance(self, PreservesZeros) # should only have a single store in prologue assert self.store_preserves_zeros is None self.store_preserves_zeros = value.is_constant() and value.expr == 0 - @staticmethod - def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: - self = V.get_ops_handler() + def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr: return construct_symbol(next(self.count), torch.int32) def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: @@ -65,12 +60,6 @@ class PreservesZeros(SymPyOps, DefaultHandler): return TypedExpr(construct_symbol(next(self.count), dtype), dtype) -if TYPE_CHECKING: - - class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]): - pass - - def prologue_preserves_zero_mask(prologue: "SchedulerNode") 
-> bool: """ Does this prologue preserve zero masks @@ -100,9 +89,8 @@ class RecordLowPrecisionOps(DefaultHandler): "constant", ) - @staticmethod - def load(name: str, index: sympy.Expr) -> DTypeContainer: - return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index)) + def load(self, name: str, index: sympy.Expr) -> DTypeContainer: + return DTypeContainer(self.dtype_prop.load(name, index)) @staticmethod def store( @@ -131,12 +119,6 @@ class RecordLowPrecisionOps(DefaultHandler): return out -if TYPE_CHECKING: - - class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]): - pass - - def low_prec_float(dtype: torch.dtype) -> bool: return dtype.is_floating_point and dtype.itemsize < 4 diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 3df87ada0dd..69c331646f8 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,14 +1,22 @@ import logging import operator from functools import partial -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union +import sympy from sympy import Expr import torch -from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRanges, +) +from ..utils._sympy.functions import PowByNatural +from ..utils._sympy.numbers import int_oo from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .ops_handler import DefaultHandler, ReductionType, StoreMode from .utils import cache_on_self, dominated_nodes from .virtualized import V @@ -139,3 +147,113 @@ class BoundVars: # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) self.replacement_vals[name] = bound return bound + + +class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler): + def __init__(self) -> None: + self.name = "ValueRangeAnalysis" + boolean_operators = ( + "xor", + "logical_and", + "logical_or", + "logical_not", + ) 
+ for op in boolean_operators: + setattr(self, op, self.bool_handler) + + @staticmethod + def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]: + # just assuming bools can have both values + return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # many ops are unlikely to show up in optimizable indexing compute, + # so we dont have full coverage + return ValueRanges.unknown() + + def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]: + return ValueRanges.unknown() + + def store( + self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None + ) -> None: + return + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Any, + ) -> ValueRanges[Any]: + return ValueRanges.unknown() + + @classmethod + def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]: + assert isinstance(index, ValueRanges) + return cls.to_dtype(index, dtype) + + @staticmethod + def to_dtype( + x: Any, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> ValueRanges[Any]: + x = ValueRanges.wrap(x) + + if dtype == torch.bool: + if x.is_singleton(): + return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x + elif 0 not in x: + return ValueRanges.wrap(sympy.true) + else: + return ValueRanges(sympy.false, sympy.true) + + def cast(x: Any, dtype: torch.dtype) -> sympy.Expr: + # dtype is int or float + if dtype.is_floating_point: + return sympy.Float(x) + else: + if x in (int_oo, -int_oo): + return x + try: + return sympy.Integer(x) + except TypeError: + # inf cannot be cast to Integer + return x + + if x.is_bool: + if x.is_singleton(): + val = 1 if x.lower else 0 + return ValueRanges.wrap(cast(val, dtype)) + else: + return ValueRanges(cast(0, dtype), cast(1, dtype)) + else: + # int to float or float to int + return 
ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) + + @staticmethod + def square(x: Any) -> ValueRanges[Any]: + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + + @staticmethod + def neg(x: Any) -> ValueRanges[Any]: + return ValueRanges.decreasing_map(x, operator.neg) + + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds + @classmethod + def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]: + x = cls.truediv(a, b) + if x == ValueRanges.unknown(): + return x + + return cls.trunc(x) + + @classmethod + def sub(cls, a: Any, b: Any) -> ValueRanges[Any]: + return cls.add(a, cls.neg(b)) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index ab5280e1d5b..dbd02188665 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT -from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from .. 
import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler -from ..ops_handler import BasicMathOps, DefaultHandler +from ..ops_handler import BasicMathOpsMixin, DefaultHandler from ..utils import ( boolean_ops, DeferredLineBase, @@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool: return True -class OpOverrides(BasicMathOps, OpDecompositions): +class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): @staticmethod def paren(string: OpVarT) -> OpVarT: if ( @@ -1235,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict( ) -if TYPE_CHECKING: - - class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class DeferredLine(DeferredLineBase): """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" @@ -2268,6 +2262,8 @@ class CSEProxy(DefaultHandler): def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): super().__init__() + from ..bounds import ValueRangeAnalysis + self.vr_analysis = ValueRangeAnalysis() self.kernel = kernel self.parent_handler = parent_handler @@ -2338,6 +2334,7 @@ class CSEProxy(DefaultHandler): If the variable comes from an FX node, we forward the bound we have already computed Else, if the variable when codegen'ing another op, we try to compute its bounds """ + from ..bounds import ValueRangeAnalysis from ..select_algorithm import TritonTemplateKernel if isinstance(V.kernel, TritonTemplateKernel): @@ -2575,9 +2572,3 @@ class CSEProxy(DefaultHandler): sorter, sorter_indices, ) - - -if TYPE_CHECKING: - - class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]): - pass diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 560e75c648f..f2bdebf3c1b 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -33,7 +33,7 @@ from ..utils import ( sympy_index_symbol, sympy_subs, ) -from ..virtualized import _ops as 
ops, OpsHandler, V +from ..virtualized import _ops as ops, V from .common import ( BackendFeature, CSEVariable, @@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides): HalideOverrides._initialize_pointwise_overrides("halide") -if TYPE_CHECKING: - - class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class HalideCSEVariable(CSEVariable): undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 86cbb6f5361..dd3ff699e8a 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: import sympy - from ..ops_handler import OpsHandler, ReductionType, StoreMode + from ..ops_handler import ReductionType, StoreMode from ..scheduler import Scheduler, SchedulerNode from .common import OpVarT @@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides): MetalOverrides._initialize_pointwise_overrides("mps") -if TYPE_CHECKING: - - class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong - - class MetalKernel(SIMDKernel): overrides = MetalOverrides # type: ignore[assignment] suffix = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 7b123f590fa..e0c1f988479 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -61,7 +61,7 @@ from ..utils import ( triton_version_uses_attrs_dict, upcast_compute_type, ) -from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V +from ..virtualized import _ops as ops, ReductionType, StoreMode, V from ..wrapper_benchmark import get_kernel_category_by_source_code from .block_analysis import BlockPatternMatcher from .common import ( @@ -1428,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides): return (mantissa, exponent) -if TYPE_CHECKING: - - class 
_typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class HelperFunctions: """An ordered set of helper functions.""" diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 820e737414d..36000a50cb8 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -4,17 +4,7 @@ import itertools import logging import re from collections.abc import Sequence -from typing import ( - Any, - Callable, - Iterable, - List, - Optional, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union from unittest.mock import patch import sympy @@ -34,7 +24,7 @@ from .utils import ( sympy_subs, VarRanges, ) -from .virtualized import OpsHandler, ReductionType, V +from .virtualized import ReductionType, V T = TypeVar("T") @@ -799,14 +789,6 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler): body() -if TYPE_CHECKING: - - class _typecheck_FreeUnbackedSymbolsOpsHandler( - FreeUnbackedSymbolsOpsHandler, OpsHandler[None] - ): - pass - - def extract_free_unbacked_symbols( fn: Callable[..., Any], index: Sequence[sympy.Expr], diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 741864e41e1..2e564041340 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -23,7 +23,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing. 
import itertools from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, overload, Union from typing_extensions import TypeAlias import sympy @@ -33,7 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges -from .ops_handler import DefaultHandler, OpsHandler +from .ops_handler import DefaultHandler from .sizevars import evaluate_expr from .utils import generate_assert from .virtualized import V @@ -370,9 +370,3 @@ class IndexPropagation(DefaultHandler): "indirect_indexing", (index, size, check, wrap_neg), {} ).value return indirect_var - - -if TYPE_CHECKING: - - class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]): - pass diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index feb88a09e75..afee8988253 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT from . 
import config, dependencies from .codegen.common import index_prevent_reordering -from .ops_handler import DefaultHandler +from .ops_handler import DefaultHandler, OpsHandler from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs from .virtualized import ops, V @@ -655,7 +655,7 @@ class LoopBodyBlock: class CountOps(DefaultHandler): - def __init__(self, inner: Any, counts: collections.Counter[str]): + def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]): self._inner = inner self._counts = counts diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 22ce7154c6b..5338372f6af 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -4,17 +4,7 @@ from __future__ import annotations import itertools import re import warnings -from typing import ( - Any, - Callable, - Literal, - NamedTuple, - Optional, - TYPE_CHECKING, - TypeVar, - Union, -) -from typing_extensions import Protocol +from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union from unittest.mock import patch import sympy @@ -48,12 +38,8 @@ def _arg_str(a: object) -> str: return str(a) -# NB: This is not done as a parent class, because our ops handlers -# implementations make heavy use of __getattr__ magic, and pre-existing -# stubs for methods would interfere with this mechanism. -# # See OpDecompositions for superclass that desugars operations like reciprocal/square. -class OpsHandler(Protocol[T]): +class OpsHandler(Generic[T]): """ Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, as well as the contract for op handlers. The type T signifies the domain @@ -77,49 +63,30 @@ class OpsHandler(Protocol[T]): ops handlers. Handlers are often defined using metaprogramming (e.g. 
_initialize_pointwise_overrides), - which means you will get type errors if you subclass OpsHandler since mypy doesn't know - about the methods added via metaprogramming and thinks the class is still abstract. - Instead, you should add a block like: - - if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - Which will check the signatures of non-meta-programmed methods and gives decent error messages. - - Some older parts of the code use a pattern like: - - def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h - - This pattern only works if the class defines a __getattr__ method, which we are moving away from. - Additionally, this pattern generates horrible error messages if the signatures are wrong. - It gives zero information about what the problem is, which makes the pattern harmful. - - Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all - operators are implemented after all the metaprogramming has run. + which means you will not get type errors for those methods. We have tests in + test/inductor/test_op_completeness.py which check that all operators are implemented after + all the metaprogramming has run. """ def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: """Produces a scalar constant of type dtype.""" - ... + raise NotImplementedError def load_seed(self, name: str, offset: T) -> T: """Computes inductor_prims.lookup_seed.""" - ... + raise NotImplementedError def rand(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" - ... + raise NotImplementedError def randn(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" - ... 
+ raise NotImplementedError def randint64(self, seed: T, offset: T, low: T, high: T) -> T: """Computes inductor_prims.randint. offset has dtype int32.""" - ... + raise NotImplementedError def masked(self, mask: T, body: Callable[[], T], other: T) -> T: """ @@ -133,13 +100,13 @@ class OpsHandler(Protocol[T]): Contrast this with ops.where, which can multiplex between two values that have been unconditionally computed. """ - ... + raise NotImplementedError def where(self, condition: T, input: T, other: T) -> T: """ Computes torch.where: when condition is true, return input; otherwise return other. """ - ... + raise NotImplementedError def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: """ @@ -147,7 +114,7 @@ class OpsHandler(Protocol[T]): an indexing expression, thus the name; however, it can also be used in non-indexing situations. """ - ... + raise NotImplementedError def to_dtype( self, @@ -160,7 +127,7 @@ class OpsHandler(Protocol[T]): Convert x to dtype. src_dtype can be optionally set to specify what the original dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). """ - ... + raise NotImplementedError def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: """ @@ -174,38 +141,38 @@ class OpsHandler(Protocol[T]): int64 depending on if we've shown that all the indexing operations can be done in int32. """ - ... + raise NotImplementedError def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with ceiling semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def floor_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with ceiling semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def round_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with round-to-even semantics. See also trunc_to_int. """ - ... 
+ raise NotImplementedError def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) src_dtype must be the original type of x. """ - ... + raise NotImplementedError def identity(self, x: T) -> T: """ Returns x as is. This is used to trigger CSE. """ - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operations are only available in a "kernel" context. Check @@ -227,13 +194,13 @@ class OpsHandler(Protocol[T]): NB: This is typically mandatory to implement for any analysis, because you MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). """ - ... + raise NotImplementedError def load(self, name: str, index: sympy.Expr) -> T: """ Load from the memory location 'name', offset by some indexing expression 'index'. """ - ... + raise NotImplementedError def store( self, @@ -246,7 +213,7 @@ class OpsHandler(Protocol[T]): Store 'value' to the memory location 'name' offset by 'expr'. If specified, 'mode' can require the store to be an atomic addition. """ - ... + raise NotImplementedError # TODO: Better explain how the "collective" semantics of these ops; # remember that the input value is a scalar, you can't reduce on it in the @@ -268,7 +235,7 @@ class OpsHandler(Protocol[T]): function returns multiple outputs; consult reduction_num_outputs to determine the amount in metaprogramming applications. """ - ... + raise NotImplementedError # TODO: in practice, this seems to actually return None, but not returning # a T makes common __getattr__ idioms not type correctly. Figure out if @@ -278,7 +245,7 @@ class OpsHandler(Protocol[T]): Store the fully accumulated result of 'reduction' to the memory location 'name' offset by 'expr'. """ - ... + raise NotImplementedError def scan( self, @@ -290,7 +257,7 @@ class OpsHandler(Protocol[T]): Perform an associative scan on 'value'. 
""" # TODO: Improve the description with some pseudocode - ... + raise NotImplementedError def sort( self, @@ -302,7 +269,7 @@ class OpsHandler(Protocol[T]): """ Sort values along the reduction dimension. """ - ... + raise NotImplementedError def bucketize( self, @@ -315,231 +282,231 @@ class OpsHandler(Protocol[T]): sorter_indices: Optional[T] = None, ) -> T: # See [Note: Inductor bucketize op] - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # The following ops have semantics that correspond exactly to the torch # operation with the same corresponding name. def abs(self, x0: T) -> T: - ... + raise NotImplementedError def exp(self, x0: T) -> T: - ... + raise NotImplementedError def exp2(self, x0: T) -> T: - ... + raise NotImplementedError def expm1(self, x0: T) -> T: - ... + raise NotImplementedError def sqrt(self, x0: T) -> T: - ... + raise NotImplementedError def relu(self, x0: T) -> T: - ... + raise NotImplementedError def minimum(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def maximum(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def cos(self, x0: T) -> T: - ... + raise NotImplementedError def sin(self, x0: T) -> T: - ... + raise NotImplementedError def lgamma(self, x0: T) -> T: - ... + raise NotImplementedError def erf(self, x0: T) -> T: - ... + raise NotImplementedError def cosh(self, x0: T) -> T: - ... + raise NotImplementedError def sinh(self, x0: T) -> T: - ... + raise NotImplementedError def acos(self, x0: T) -> T: - ... + raise NotImplementedError def acosh(self, x0: T) -> T: - ... + raise NotImplementedError def asin(self, x0: T) -> T: - ... + raise NotImplementedError def asinh(self, x0: T) -> T: - ... + raise NotImplementedError def atan2(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def atan(self, x0: T) -> T: - ... + raise NotImplementedError def atanh(self, x0: T) -> T: - ... + raise NotImplementedError def copysign(self, x0: T, x1: T) -> T: - ... 
+ raise NotImplementedError def erfc(self, x0: T) -> T: - ... + raise NotImplementedError def erfinv(self, x0: T) -> T: - ... + raise NotImplementedError def frexp(self, x0: T): - ... + raise NotImplementedError def hypot(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def log10(self, x0: T) -> T: - ... + raise NotImplementedError def log2(self, x0: T) -> T: - ... + raise NotImplementedError def nextafter(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_and(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_not(self, x0: T) -> T: - ... + raise NotImplementedError def logical_or(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_and(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_not(self, x0: T) -> T: - ... + raise NotImplementedError def bitwise_or(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_left_shift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_right_shift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def rsqrt(self, x0: T) -> T: - ... + raise NotImplementedError def log1p(self, x0: T) -> T: - ... + raise NotImplementedError def tan(self, x0: T) -> T: - ... + raise NotImplementedError def tanh(self, x0: T) -> T: - ... + raise NotImplementedError def sigmoid(self, x0: T) -> T: - ... + raise NotImplementedError def signbit(self, x0: T) -> T: - ... + raise NotImplementedError def fmod(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def log(self, x0: T) -> T: - ... + raise NotImplementedError def isinf(self, x0: T) -> T: - ... + raise NotImplementedError def isnan(self, x0: T) -> T: - ... 
+ raise NotImplementedError # NB: this returns a float, like the torch operation # This rounds half to even to break ties def round(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: - ... + raise NotImplementedError def sign(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: - ... + raise NotImplementedError def neg(self, x0: T) -> T: - ... + raise NotImplementedError def reciprocal(self, x0: T) -> T: - ... + raise NotImplementedError def eq(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def ne(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def lt(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def gt(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def le(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def ge(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def add(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def sub(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def mul(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def and_(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def or_(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # These are metaprogrammed by MockHandler._init_cls def lshift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def rshift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These are "special" operators. 
These only exist if the target @@ -547,124 +514,124 @@ class OpsHandler(Protocol[T]): # pointwise_overrides_data. def airy_ai(self, x: T) -> T: - ... + raise NotImplementedError def bessel_j0(self, x: T) -> T: - ... + raise NotImplementedError def bessel_j1(self, x: T) -> T: - ... + raise NotImplementedError def bessel_y0(self, x: T) -> T: - ... + raise NotImplementedError def bessel_y1(self, x: T) -> T: - ... + raise NotImplementedError def digamma(self, x: T) -> T: - ... + raise NotImplementedError def erfcx(self, x: T) -> T: - ... + raise NotImplementedError def fma(self, x: T, y: T, z: T) -> T: - ... + raise NotImplementedError def igamma(self, x: T, y: T) -> T: - ... + raise NotImplementedError def igammac(self, x: T, y: T) -> T: - ... + raise NotImplementedError def gammainc(self, x: T, y: T) -> T: - ... + raise NotImplementedError def gammaincc(self, x: T, y: T) -> T: - ... + raise NotImplementedError def i0(self, x: T) -> T: - ... + raise NotImplementedError def i0e(self, x: T) -> T: - ... + raise NotImplementedError def i1(self, x: T) -> T: - ... + raise NotImplementedError def i1e(self, x: T) -> T: - ... + raise NotImplementedError def log_ndtr(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_i0(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_i1(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_k0(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_k1(self, x: T) -> T: - ... + raise NotImplementedError def ndtr(self, x: T) -> T: - ... + raise NotImplementedError def ndtri(self, x: T) -> T: - ... + raise NotImplementedError def polygamma(self, x: T, y: T) -> T: - ... + raise NotImplementedError def scaled_modified_bessel_k0(self, x: T) -> T: - ... + raise NotImplementedError def scaled_modified_bessel_k1(self, x: T) -> T: - ... + raise NotImplementedError def spherical_bessel_j0(self, x: T) -> T: - ... + raise NotImplementedError def zeta(self, x: T, y: T) -> T: - ... 
+ raise NotImplementedError def chebyshev_polynomial_t(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_u(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_v(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_w(self, x: T, y: T) -> T: - ... + raise NotImplementedError def legendre_polynomial_p(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: - ... + raise NotImplementedError def hermite_polynomial_h(self, x: T, y: T) -> T: - ... + raise NotImplementedError def hermite_polynomial_he(self, x: T, y: T) -> T: - ... + raise NotImplementedError def laguerre_polynomial_l(self, x: T, y: T) -> T: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operators are a bit special, because they are conventionally @@ -675,42 +642,42 @@ class OpsHandler(Protocol[T]): """C-style trunc division between integers only. Computes the true division of two numbers and rounds the result to zero. """ - ... + raise NotImplementedError def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the true division of two numbers and floors the result. If you want floor division for floats, do regular truediv and floor the result. """ - ... + raise NotImplementedError def truediv(self, x0: T, x1: T) -> T: """True division between floats. Integer inputs are NOT valid. To do Python-style (int, int) -> float division, use int_truediv""" - ... + raise NotImplementedError def int_truediv(self, x0: T, x1: T) -> T: """True division between integers. 
This is NOT the same as promoting to float and doing integer division, there is a bespoke algorithm for doing the division in higher precision than the above. """ - ... + raise NotImplementedError def mod(self, x0: T, x1: T) -> T: """C-style modulus, take sign from LHS (x0).""" - ... + raise NotImplementedError def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" - ... + raise NotImplementedError def square(self, x0: T) -> T: - ... + raise NotImplementedError def check_bounds( self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool ) -> None: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are @@ -726,25 +693,25 @@ class OpsHandler(Protocol[T]): # for many analyses it's not conveniently available.) def libdevice_abs(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_exp(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sqrt(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_cos(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sin(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sigmoid(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_log(self, x0: T) -> T: - ... + raise NotImplementedError # halide-only def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T: @@ -760,15 +727,15 @@ class OpsHandler(Protocol[T]): is_pure: bool = True, pack: int = 1, ) -> T: - ... + raise NotImplementedError def output(self, x0: T) -> None: """This is a fake op used in analysis but not codegen""" - ... + raise NotImplementedError def placeholder(self, index: int) -> T: """This is a fake op used in analysis but not codegen""" - ... 
+ raise NotImplementedError _ignore_op_re = re.compile(r"_.*|paren").fullmatch @@ -781,7 +748,7 @@ def list_ops(cls: type[Any]): OP_NAMES = list_ops(OpsHandler) -class DefaultHandler: +class DefaultHandler(OpsHandler[Any]): def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: """ Default implementation for all ops. Override in a subclass to @@ -850,13 +817,7 @@ class NoopHandler(DefaultHandler): return sympy.S.Zero -if TYPE_CHECKING: - - class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]): - pass # mypy will error if we got any of the signatures wrong - - -class BasicMathOps: +class BasicMathOpsMixin: @staticmethod def add(a, b): return f"{a} + {b}" @@ -935,7 +896,7 @@ class BasicMathOps: return f"-{a}" -class MockHandler(BasicMathOps, DefaultHandler): +class MockHandler(BasicMathOpsMixin, DefaultHandler): name = "MockHandler" def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: @@ -971,14 +932,8 @@ class MockHandler(BasicMathOps, DefaultHandler): return sympy_index_symbol(str(index_var)) -if TYPE_CHECKING: - - class _typecheck_MockHandler(MockHandler, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class KernelFormatterHandler(DefaultHandler): - def __init__(self, parent_handler): + def __init__(self, parent_handler: OpsHandler[Any]): self.parent_handler = parent_handler self._output = IndentedBuffer(1) self.var_counter = itertools.count() @@ -1042,14 +997,8 @@ class KernelFormatterHandler(DefaultHandler): return self._output.getvalue() -if TYPE_CHECKING: - - class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class WrapperHandler(DefaultHandler): - def __init__(self, inner: Any): + def __init__(self, inner: OpsHandler[Any]): self._inner = inner def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: @@ -1074,11 +1023,11 @@ class 
OpCountResult(NamedTuple): class OpCounterCSE(DefaultHandler): """Shim to count how many ops are used""" - def __init__(self, inner): + def __init__(self, inner: OpsHandler[Any]): super().__init__() self.parent_handler = inner self.op_count = 0 - self.var_names = {} + self.var_names: dict[str, str] = {} self._used_ops: OrderedSet[str] = OrderedSet() self._read_names: list[str] = [] self._nontrivial_read_count = 0 @@ -1152,26 +1101,16 @@ class OpCounterCSE(DefaultHandler): ) -if TYPE_CHECKING: - - class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class ExtractConstantsHandler(NoopHandler): - def __init__(self, device): + def __init__(self, device: Optional[torch.device]): self.device = device def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant: from torch._inductor import ir - return ir.Constant(value=value, dtype=dtype, device=self.device) - - -if TYPE_CHECKING: - - class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong + return ir.Constant( + value=value, dtype=dtype, device=self.device or torch.get_default_device() + ) class SimpleCSEHandler(WrapperHandler): @@ -1204,9 +1143,3 @@ class SimpleCSEHandler(WrapperHandler): val = getattr(self._inner, name)(*args, **kwargs) self.cse_cache[key] = val return val - - -if TYPE_CHECKING: - - class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 393e282d03c..66f60b12f16 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode): return artifact_path -def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode: - return h - - @dataclasses.dataclass class CompiledAOTI(OutputCode): """ @@ 
-591,10 +587,6 @@ class CompiledAOTI(OutputCode): pass -def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode: - return h - - @dataclasses.dataclass class MockFXGraphCacheOutput(OutputCode): gm: Any = None diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 0166534e5fb..d7992385735 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler): def placeholder(self, idx: int) -> torch.fx.Proxy: return self.placeholders[idx] - def output(self, *args: tuple[object]) -> torch.fx.Node: - return self.tracer.create_node( + def output(self, *args: tuple[object]) -> None: + self.tracer.create_node( "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} ) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index f82c84afdad..1ee1ef5a744 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -306,9 +306,7 @@ class OpsWrapper(DefaultHandler): return _ops.indirect_indexing(index, size, check, wrap_neg) -# we lie about the type of ops so the rest of the codebase typecheck properly -# DefaultHandler implements the OpsHandler protocol via metaprogramming -ops = cast(OpsHandler[Any], OpsWrapper()) +ops: OpsHandler[Any] = OpsWrapper() class _V: @@ -316,8 +314,10 @@ class _V: KernelFormatterHandler = KernelFormatterHandler WrapperHandler = WrapperHandler - set_ops_handler: Callable[[Any], Any] = _ops._set_handler - get_ops_handler: Callable[[], Any] = _ops._get_handler + set_ops_handler: Callable[ + [OpsHandler[Any]], AbstractContextManager[None] + ] = _ops._set_handler + get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler get_real_inputs: Callable[[], Any] = _real_inputs._get_handler diff --git 
a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index eb85b6798ea..784f9e7ba05 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -49,7 +49,7 @@ from .numbers import int_oo, IntInfinity, NegativeIntInfinity log = logging.getLogger(__name__) -__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"] +__all__ = ["ValueRanges", "bound_sympy"] _T = TypeVar("_T", sympy.Expr, SympyBoolean) @@ -1004,108 +1004,6 @@ class SymPyValueRangeAnalysis: return ValueRanges.increasing_map(x, TruncToFloat) -class ValueRangeAnalysis(SymPyValueRangeAnalysis): - def __init__(self) -> None: - self.name = "ValueRangeAnalysis" - boolean_operators = ( - "xor", - "logical_and", - "logical_or", - "logical_not", - ) - for op in boolean_operators: - setattr(self, op, self.bool_handler) - - @staticmethod - def bool_handler(*args, **kwargs): - # just assuming bools can have both values - return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] - - @staticmethod - def default_handler(*args, **kwargs): - # many ops are unlikely to show up in optimizable indexing compute, - # so we dont have full coverage - return ValueRanges.unknown() - - def load(self, name: str, index: sympy.Expr): - return ValueRanges.unknown() - - def store(self, name, index, value, mode=None): - return - - def reduction(self, name, dtype, src_dtype, reduction_type, index, value): - return ValueRanges.unknown() - - @classmethod - def index_expr(cls, index, dtype): - assert isinstance(index, ValueRanges) - return cls.to_dtype(index, dtype) - - @staticmethod - def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): - x = ValueRanges.wrap(x) - - if dtype == torch.bool: - if x.is_singleton(): - return ValueRanges.wrap(x.lower != 0) - elif x.is_bool: - return x - elif 0 not in x: - return ValueRanges.wrap(sympy.true) - else: - return ValueRanges(sympy.false, sympy.true) - - def cast(x, dtype): - # dtype is int or float - if 
dtype.is_floating_point: - return sympy.Float(x) - else: - if x in (int_oo, -int_oo): - return x - try: - return sympy.Integer(x) - except TypeError: - # inf cannot be cast to Integer - return x - - if x.is_bool: - if x.is_singleton(): - val = 1 if x.lower else 0 - return ValueRanges.wrap(cast(val, dtype)) - else: - return ValueRanges(cast(0, dtype), cast(1, dtype)) - else: - # int to float or float to int - return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) - - @staticmethod - def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) - - @staticmethod - def neg(x): - return ValueRanges.decreasing_map(x, operator.neg) - - # TODO: this is slightly inaccurate because truncdiv operates at integer - # precision, but we're going through float truediv which means we can - # potentially lose precision on the bounds - @classmethod - def truncdiv(cls, a, b): - x = cls.truediv(a, b) - if x == ValueRanges.unknown(): - return x - - return cls.trunc(x) - - @classmethod - def sub(cls, a, b): - return cls.add(a, cls.neg(b)) - - def __getattr__(self, name): - log.debug("unhandled ValueRange op %s", name) - return self.default_handler - - def bound_sympy( expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: From d35f6b2339384126559938ede05d89b487597ee2 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:55 -0800 Subject: [PATCH 06/28] [inductor] Minor compile time optimizations in DefaultHandler (#146282) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146282 Approved by: https://github.com/shunting314 ghstack dependencies: #146252, #146254, #146255, #146257 --- .../pr_time_benchmarks/benchmark_runner.sh | 0 torch/_inductor/codegen/common.py | 2 +- torch/_inductor/dtype_propagation.py | 2 +- torch/_inductor/loop_body.py | 4 +- torch/_inductor/ops_handler.py | 41 +++++++++++++++++-- 5 files changed, 42 insertions(+), 7 deletions(-) mode change 100644 => 100755 
benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh old mode 100644 new mode 100755 diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index dbd02188665..fec37fb6002 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -948,7 +948,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" ) - def output(self, x0: OpVarT) -> None: + def output(self, *args: OpVarT) -> None: raise AssertionError( f"{type(self).__name__}: ops.output should not appear at codegen time" ) diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index 5b45943b940..256079c8071 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -368,7 +368,7 @@ class DtypePropagationOpsHandler: ) -> None: return None - def output(self, x: DTypeArg) -> None: + def output(self, *args: DTypeArg) -> None: raise AssertionError( f"{type(self).__name__}: ops.output should not appear here" ) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index afee8988253..4968544d80f 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -602,8 +602,8 @@ class LoopBodyBlock: return var @staticmethod - def output(result): - tracer.create_proxy("output", "output", (result,), {}) + def output(*result): + tracer.create_proxy("output", "output", result, {}) tracer = LightTracer() proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 5338372f6af..0118d29368c 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -1,9 +1,11 @@ # mypy: allow-untyped-defs from __future__ import annotations 
+import inspect import itertools import re import warnings +from io import StringIO from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union from unittest.mock import patch @@ -729,7 +731,7 @@ class OpsHandler(Generic[T]): ) -> T: raise NotImplementedError - def output(self, x0: T) -> None: + def output(self, *args: T) -> None: """This is a fake op used in analysis but not codegen""" raise NotImplementedError @@ -755,7 +757,7 @@ class DefaultHandler(OpsHandler[Any]): provide generic op behavior. Args: - target: name of the op, see OpHandler.target + name: name of the op, see OpHandler.{name} args: positional args passed to the op kwargs: keyword args passed to the op @@ -783,8 +785,41 @@ class DefaultHandler(OpsHandler[Any]): @classmethod def _init_cls(cls): + """ + Here we codegen many functions of the form: + + def add(self, a, b): + return self._default('add', (a, b), {}) + + and install them in cls. This is the same as _call_default above, + but is about 1.2x faster since CPython varargs parsing is slow. 
+ """ + code = StringIO() for target in OP_NAMES: - setattr(cls, target, cls._call_default(target)) + sig = inspect.signature(getattr(OpsHandler, target)) + if all( + p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is inspect.Parameter.empty + for p in sig.parameters.values() + ): + self_arg, *args = sig.parameters.keys() + assert self_arg == "self" + code.write( + f""" + def {target}(self, {', '.join(args)}): + return self._default({target!r}, ({', '.join(args)}, ), {{}}) + """.strip() + ) + code.write("\n\n") + else: + # slower fallback for ops with default or variadic arguments + setattr(cls, target, cls._call_default(target)) + + ctx: dict[str, Any] = {} + exec(code.getvalue(), ctx) + for target, impl in ctx.items(): + if target in OP_NAMES: + setattr(cls, target, impl) DefaultHandler._init_cls() From c098385cb3e68682c1f1762ff461d3d59b34651d Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:55 -0800 Subject: [PATCH 07/28] [inductor] Refactor CaptureIndexing into global scope (#146297) And inline SimplifyIndexing into it CaptureIndexing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146297 Approved by: https://github.com/shunting314 ghstack dependencies: #146252, #146254, #146255, #146257, #146282 --- torch/_inductor/loop_body.py | 346 ++++++++++++++++++----------------- 1 file changed, 178 insertions(+), 168 deletions(-) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 4968544d80f..c3a3ab7133e 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT from . 
import config, dependencies from .codegen.common import index_prevent_reordering -from .ops_handler import DefaultHandler, OpsHandler +from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs from .virtualized import ops, V @@ -440,179 +440,13 @@ class LoopBodyBlock: def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]): self.body = body - - def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): - return tracer.create_proxy( - "call_module", - "get_index", - (body.add_index_expr(expr, mtype, **kwargs),), - {}, - ) - - class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] - name = "CaptureIndexing" - - def load(self, name: str, index: sympy.Expr): - index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) - return self._inner.load(name, index) - - def load_seed(self, name: str, index: int): - assert isinstance(index, int) - body.add_index_expr( - sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name - ) - return self._inner.load_seed(name, index) - - def store(self, name, index, value, mode=None): - index = add_index( - index, MemoryUsageType.STORE, buffer_name=name, mode=mode - ) - return self._inner.store(name, index, value, mode) - - def store_reduction(self, name, index, value): - index = add_index( - index, MemoryUsageType.STORE_REDUCTION, buffer_name=name - ) - return self._inner.store_reduction(name, index, value) - - def reduction(self, dtype, src_dtype, reduction_type, value): - result = self._inner.reduction(dtype, src_dtype, reduction_type, value) - if "welford" in reduction_type: - return tuple(result[i] for i in range(3)) - return result - - def index_expr(self, index, dtype): - if isinstance(index, (int, sympy.Integer)): - return self._inner.constant(int(index), dtype) - index = add_index(index, MemoryUsageType.INDEX_EXPR) - return self._inner.index_expr(index, dtype) - - def check_bounds(self, 
index, size, lower, upper): - index = add_index(index, MemoryUsageType.CHECK_BOUNDS) - size = add_index(size, MemoryUsageType.CHECK_BOUNDS) - return self._inner.check_bounds(index, size, lower, upper) - - def bucketize( - self, - values: T, - boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], - boundary_indices: T, - indexing_dtype: torch.dtype, - right: bool, - sorter: Optional[tuple[str, sympy.Expr]] = None, - sorter_indices: Optional[T] = None, - ) -> T: - """ - See [Note: Inductor bucketize op] - """ - boundaries = ( - boundaries[0], - add_index( - boundaries[1], - MemoryUsageType.BUCKETIZE, - buffer_name=boundaries[0], - ), - add_index( - boundaries[2], - MemoryUsageType.BUCKETIZE, - buffer_name=boundaries[0], - ), - add_index( - boundaries[3], - MemoryUsageType.BUCKETIZE, - buffer_name=boundaries[0], - ), - ) - if sorter is not None: - sorter = ( - sorter[0], - add_index( - sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0] - ), - ) - - return self._inner.bucketize( - values, - boundaries, - boundary_indices, - indexing_dtype, - right, - sorter, - sorter_indices, - ) - - @staticmethod - def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): - """ - Recursively capture the masked out body in another LoopBodyBlock - """ - name = self.body.add_submodule(None, "masked_subblock") - self.body.submodules[name] = self.body.bind_masked_shim(name) - self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) - return tracer.create_proxy( - "call_module", name, (mask_proxy, other_proxy), {} - ) - - @staticmethod - def scan( - dtype_proxy, - combine_fn: Callable[ - [tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...] 
- ], - value_proxy, - ): - shim = self.body.bind_scan_shim(combine_fn) - name = self.body.add_submodule(shim, "scan") - result = tracer.create_proxy( - "call_module", - name, - (dtype_proxy, value_proxy), - {}, - ) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(value_proxy))) - - def sort(self, dtypes, values, stable, descending): - result = self._inner.sort(dtypes, values, stable, descending) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(values))) - - def frexp(self, value_proxy): - result = self._inner.frexp(value_proxy) - # Proxies are iterable, but some methods expect tuples/lists - return (result[0], result[1]) - - @staticmethod - def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): - """ - Flow data from tensors into indexing formulas. - Introduce a call_module to update the indexing. - """ - - var = self.body.add_indirect(size) - set_indirect = self.body.bind_set_indirect_shim( - var, size, check, wrap_neg - ) - tracer.create_proxy( - "call_module", - self.body.add_submodule(set_indirect, f"set_{var}"), - (index_proxy,), - {}, - ) - return var - - @staticmethod - def output(*result): - tracer.create_proxy("output", "output", result, {}) - tracer = LightTracer() proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) from .index_propagation import IndexPropagation - from .sizevars import SimplifyIndexing handler: Any = CountOps( - SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges), + CaptureIndexing(proxy_ops, body, tracer), body.op_counts, ) if config.constant_and_index_propagation: @@ -662,3 +496,179 @@ class CountOps(DefaultHandler): def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: self._counts[name] += 1 return getattr(self._inner, name)(*args, **kwargs) + + +class CaptureIndexing(WrapperHandler): + name = "CaptureIndexing" + + def __init__( + self, + inner: 
OpsHandler[Any], + body: LoopBody, + tracer: LightTracer, + ): + super().__init__(inner) + self.body = body + self.tracer = tracer + + def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any): + return self.tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + def _simplify(self, expr: sympy.Expr) -> sympy.Expr: + return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges) + + def load(self, name: str, index: sympy.Expr): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + self.body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, value, mode=None): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + index = self._simplify(index) + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = self._add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = self._simplify(index) + index = self._add_index(index, 
MemoryUsageType.CHECK_BOUNDS) + size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + """ + See [Note: Inductor bucketize op] + """ + boundaries = ( + boundaries[0], + self._add_index( + boundaries[1], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[2], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[3], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + ) + if sorter is not None: + sorter = ( + sorter[0], + self._add_index( + sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0] + ), + ) + + return self._inner.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return self.tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + def scan( + self, + dtype_proxy, + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = self.tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, 
dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. + """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg) + self.tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + def output(self, *result): + self.tracer.create_proxy("output", "output", result, {}) From eee5622b98d547199391803170fa19cea6525448 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 7 Feb 2025 13:32:55 -0800 Subject: [PATCH 08/28] [inductor] Pre-populate cache for simplify_with_ranges return value (#146373) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146373 Approved by: https://github.com/yanboliang, https://github.com/shunting314 ghstack dependencies: #146252, #146254, #146255, #146257, #146282, #146297 --- torch/_inductor/sizevars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 532073a377f..fcc549bf652 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -99,6 +99,8 @@ class SizeVarAllocator: if result is None: result = self._simplify_with_ranges(expr, var_ranges) cache[key] = result + if result != expr: + cache[(result, *var_ranges.items())] = result return result return simplify_with_ranges From a1bfb39a31aff91dfeba8730f5b496cb1e44ceca Mon Sep 17 00:00:00 2001 From: Blaine Burton Rister 
<145300525+blaine-rister@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:11:53 +0000 Subject: [PATCH 09/28] [Inductor] Expand Identity ops prior to block pattern matching (#146000) # Feature Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell if it can be represented as a block pointer, that analysis should be invariant to `Identity`'s. This PR adds a few features to achieve this invariance. - Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression. - Preprocess the expression with this expansion prior to pattern matching. - Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`. # Test plan This PR adds a few new unit tests: - Added a unit test specifically for `expr.expand(identity=True)`. - Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing Identity ops. I originally intended to add an end to end test compiling pointwise cat, and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000 Approved by: https://github.com/eellison, https://github.com/jansel --- test/inductor/test_block_analysis.py | 102 ++++++++++++++++++++++ test/test_sympy_utils.py | 12 +++ torch/_inductor/codegen/block_analysis.py | 29 +++++- torch/_inductor/codegen/triton.py | 14 +-- torch/testing/_internal/inductor_utils.py | 18 ++++ torch/utils/_sympy/functions.py | 4 + 6 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 test/inductor/test_block_analysis.py diff --git a/test/inductor/test_block_analysis.py b/test/inductor/test_block_analysis.py new file mode 100644 index 00000000000..5cf932d52e8 --- /dev/null +++ b/test/inductor/test_block_analysis.py @@ -0,0 +1,102 @@ +# Owner(s): ["module: inductor"] + +import sympy + +import torch +from torch._inductor.codegen.block_analysis import BlockPatternMatcher +from torch._inductor.virtualized import V +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.testing._internal.inductor_utils import dummy_graph +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing + + +# Some useful symbols +x, y = sympy.symbols("x y") + + +@instantiate_parametrized_tests +class BlockAnalysisTest(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Create a GraphLowering, so we can access V.graph. + cls.graph = dummy_graph() + + @parametrize( + "stride,symbol,expr", + [ + (5, x, Identity(5 * x)), + (4, y, 4 * Identity(y)), + (3, x, Identity(3) * x), + ], + ) + def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr): + # Test that we can handle an identity expression in affine indexing. 
+ matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol) + self.assertEqual(matched_stride, stride) + + @parametrize( + "dims,strides,symbol,expr", + [ + ( + (2, 4), + (4, 1), + x, + 4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4), + ), + ( + (3, 9), + (5, 2), + x, + 5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9), + ), + ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))), + ], + ) + def test_mod_div_identity( + self, + dims: tuple[int], + strides: tuple[int], + symbol: sympy.Symbol, + expr: sympy.Expr, + ): + # Test that we can handle an identity expression in modular indexing. + numel = int(torch.prod(torch.Tensor(dims))) + num_dims = len(dims) + with V.set_graph_handler(self.graph): + match_result = BlockPatternMatcher.match_mod_div_block_expr( + expr, symbol, numel, num_dims + ) + + # Check the matched block dimensions. + self.assertNotEqual(match_result, None) + matched_dims, matched_strides, matched_block_index_exprs = match_result + self.assertEqual(matched_dims, dims) + self.assertEqual(matched_strides, strides) + + @parametrize( + "symbol,expr,subexpr", + [ + (x, Identity(x), x), + (x, Identity(x + 5), x), + (y, Identity(x + 2 * y) + 5, 2 * y), + ], + ) + def test_subexpr_identity( + self, + symbol: sympy.Symbol, + expr: sympy.Expr, + subexpr: sympy.Expr, + ): + matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol) + self.assertEqual(matched_subexpr, subexpr) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index dddb73c2851..e804e289c1c 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import ( ) from torch.utils._sympy.functions import ( FloorDiv, + Identity, OpaqueUnaryFn_cos, simple_floordiv_gcd, ) @@ -955,6 +956,17 @@ class TestSingletonInt(TestCase): self.assertEqual(j1.free_symbols, set()) +class TestIdentity(TestCase): + 
def test_expand_identity(self): + """ + Test removing an identity via expansion. + """ + x = sympy.Symbol("x") + arg = x + sympy.S.One + expr = Identity(arg) + expanded = expr.expand(identity=True) + self.assertEqual(expanded.count(Identity), 0) + self.assertEqual(expanded, arg) instantiate_parametrized_tests(TestValueRanges) instantiate_parametrized_tests(TestSympyInterp) diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index 484fa135986..1c816eb8e29 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -17,8 +17,8 @@ class BlockPatternMatcher: Matches block indexing expressions. """ - @staticmethod - def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr: + @classmethod + def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: """ Given a sympy expression, return the subexpression comprised only of terms involving the specified symbol. @@ -26,6 +26,7 @@ class BlockPatternMatcher: For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, this returns `x * 5 + x ** 2`. """ + expr = cls._preprocess(expr) return sympy.S.Zero + sum( term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols ) @@ -42,6 +43,11 @@ class BlockPatternMatcher: numels.appendleft(numel) return [*numels] + @staticmethod + def _preprocess(expr: Expr) -> Expr: + # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y. + return expr.expand(identity=True) + @classmethod def match_mod_div_block_expr( cls, @@ -54,6 +60,7 @@ class BlockPatternMatcher: Matches modular indexing expressions, converting them to implied block dimensions and strides. See triton.py for more information. """ + index = cls._preprocess(index) # Pattern match to find the strides and offset. 
wild = functools.partial(sympy.Wild, exclude=[index_var]) @@ -141,3 +148,21 @@ class BlockPatternMatcher: ) return dims, strides, block_index_exprs + + @classmethod + def match_affine_block_expr( + cls, + index: Expr, + index_var: Symbol, + ) -> Optional[Expr]: + """ + Matches simple expressions of the form stride * index, returning the + stride. + """ + index = cls._preprocess(index) + stride = sympy.Wild("stride", exclude=[index_var]) + m = index.match(index_var * stride) + if m is None: + return None + + return m[stride] diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e0c1f988479..5f3e1b78783 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1790,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): and self.index_dtype == "tl.int32" ): - def match_strided_block( + def match_affine_block( index: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ @@ -1799,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): This implies stride (s,), and shape (XBLOCK,). """ - symbol = range_tree.symbol() - stride = sympy.Wild("stride", exclude=[symbol]) - m = index.match(symbol * stride) - if m is None: + stride = BlockPatternMatcher.match_affine_block_expr( + index, range_tree.symbol() + ) + if stride is None: return None return BlockParameters( shape=[range_tree.numel], block_shape=[TritonSymbols.get_block_size(range_tree)], - strides=[m[stride]], + strides=[stride], offsets=[TritonSymbols.get_block_offset(range_tree)], ) @@ -1917,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): Match a block indexing subexpression involving a single range tree. 
""" for match_func in ( - match_strided_block, + match_affine_block, match_mod_div_block, ): match = match_func(expr, range_tree) diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 3110c3947af..13de003b330 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -10,6 +10,9 @@ import os from subprocess import CalledProcessError import sys import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch.fx.experimental.proxy_tensor import make_fx +from torch._inductor.graph import GraphLowering +from torch._inductor.compile_fx import shape_env_from_inputs from torch._inductor.codecache import CppCodeCache from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu from torch._inductor.utils import GPU_TYPES, get_gpu_type @@ -142,6 +145,21 @@ IS_H100 = LazyVal( IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) +def dummy_graph() -> GraphLowering: + """ + Create a graph. This is useful for unit testing code which accesses + V.graph.sizevars. + """ + example_inputs = [torch.randn(10) for _ in range(2)] + gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs) + shape_env = shape_env_from_inputs(example_inputs) + graph = GraphLowering( + gm, + shape_env=shape_env, + ) + + return graph + def maybe_skip_size_asserts(op): """ For certain ops, there meta and eager implementation returns differents diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index ae0a1eee398..15db18f3307 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1286,6 +1286,10 @@ class Identity(sympy.Function): def _eval_is_integer(self): return self.args[0].is_integer # type: ignore[attr-defined] + def _eval_expand_identity(self, **hints): + # Removes the identity op. 
+ return self.args[0] + def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function): From 92b7e610abf3e3ea3cded019a0976718675d480f Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 7 Feb 2025 16:58:56 -0800 Subject: [PATCH 10/28] [Inductor changes] Invoke Quant (#139102) Adds a `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0). The primary motivations are - Unifying scattered reasoning for quant operators throughout the code base - Easy of pattern matching - see this very large pattern match expression [here](https://github.com/pytorch/pytorch/blob/949fdd299764d4fbefe1db093717786d946aaa60/torch/_inductor/fx_passes/post_grad.py#L390-L426. Compared to the pattern I have in the tests: ``` @register_graph_pattern( CallFunction( torch.ops.aten.mm, CallFunction( torch.ops.higher_order.invoke_quant, Ignored(), Ignored(), Ignored(), scheme="nf4", ), Arg(), ), pass_dict=test_pass, ) ``` - Ability to specify inductor specific logic, like codegen'ing the operators in lower precision, or forcing fusion to a matmul. 
Example graph: ``` Python ===== AFTER POST GRAD ===== /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class (torch.nn.Module): def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"): # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg] repeated_subgraph0 = self.repeated_subgraph0 invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); repeated_subgraph0 = arg0_1 = arg1_1 = None return (invoke_quant,) class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"): # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg] mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = None add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1); mul = arg1_1 = None return add ``` The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present. I wasn't sure exactly how the inductor specific configurations like `codgen_in_low_precision` should be passed through. I didnt want to stuff them all in as kwargs, and I didn't want to have them affect pattern matching. So they will be stored as meta of the node itself. And, following that, I wanted the invocation of the hop to match how it will show up in the graph. So I decided to have it be an object that is then invoked for the tracing. ``` invoke_quant = InvokeQuant(codegen_low_precision=True) invoke_quant(gn, (x, y), scheme="nf4") ``` Todo - not require the packing of args in a tuple, will do following https://github.com/pytorch/pytorch/pull/139162. Feedback welcome. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102 Approved by: https://github.com/Chillee --- test/dynamo/test_higher_order_ops.py | 6 + test/higher_order_ops/test_invoke_quant.py | 183 +++++++++++++++++++++ torch/_dynamo/trace_rules.py | 1 + torch/_higher_order_ops/__init__.py | 8 + torch/_higher_order_ops/_invoke_quant.py | 72 ++++++++ torch/_inductor/fx_passes/joint_graph.py | 85 ++++++++++ torch/_inductor/lowering.py | 17 ++ torch/testing/_internal/hop_db.py | 59 +++++++ 8 files changed, 431 insertions(+) create mode 100644 test/higher_order_ops/test_invoke_quant.py create mode 100644 torch/_higher_order_ops/_invoke_quant.py diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 3704d9e5c53..6960698382d 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -7000,6 +7000,12 @@ class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase): ) def test_hops_compile(self, device, dtype, op, backend): # Ensure HOPs can be compiled + + if backend == "aot_eager" and op.name == "invoke_quant": + raise unittest.SkipTest( + "TODO: partitioner fails. 
migrate canonicalization to aot eager backend" + ) + sample_inputs_itr = op.sample_inputs( device, dtype, requires_grad=op.supports_autograd ) diff --git a/test/higher_order_ops/test_invoke_quant.py b/test/higher_order_ops/test_invoke_quant.py new file mode 100644 index 00000000000..96addfe1aae --- /dev/null +++ b/test/higher_order_ops/test_invoke_quant.py @@ -0,0 +1,183 @@ +# Owner(s): ["module: higher order operators"] +# flake8: noqa: B950 + +import contextlib +import logging +import unittest + +import torch +import torch._dynamo +import torch._functorch +import torch._inductor +import torch._inductor.decomposition +from torch._higher_order_ops import InvokeQuant +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + Ignored, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from torch.testing import FileCheck +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +invoke_quant_tracer = InvokeQuant() + + +@skipIfTorchDynamo("Not a torch._dynamo test") +class TestInvokeQuant(TestCase): + backend = "" + + def test_simple(self): + def gn(x, y): + return (torch.mul(x, y) + y,) + + def fn(x, y): + return invoke_quant_tracer( + gn, (x, y), scheme="nf4", quant_options=invoke_quant_tracer + )[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(False) + y_clone = y.clone().detach().requires_grad_(False) + res = torch.compile(fn, backend=self.backend)(x_clone, y_clone) + self.assertEqual(ref, res) + + def test_construct_inline(self): + def gn(x, y): + return (torch.mul(x, y) + y,) + + def fn(x, y): + return InvokeQuant(codegen_low_precision=False)(gn, (x, y), scheme="nf4")[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(False) + y_clone = y.clone().detach().requires_grad_(False) + res = 
torch.compile(fn, backend=self.backend)(x_clone, y_clone) + self.assertEqual(ref, res) + + def test_inline(self): + def gn(x, y): + return (torch.mul(x, y) + y,) + + def fn(x, y): + return InvokeQuant()(gn, (x, y), scheme="nf4")[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(False) + y_clone = y.clone().detach().requires_grad_(False) + res = torch.compile(fn, backend=self.backend)(x_clone, y_clone) + self.assertEqual(ref, res) + + def test_multiple(self): + torch._logging.set_logs(post_grad_graphs=True) + + def gn(x, y): + return torch.mul(x, y) + y + + def fn(x, y, z): + o1 = invoke_quant_tracer(gn, (x, y), scheme="nf4") + o2 = invoke_quant_tracer(gn, (y, z), scheme="nf4") + return o1 + o2 + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + z = torch.randn(8, requires_grad=False) + ref = fn(x, y, z) + + log_context = ( + contextlib.nullcontext() + if self.backend != "inductor" + else self.assertLogs(logger="torch._inductor", level=logging.DEBUG) + ) + + with log_context as log: + res = torch.compile(fn, backend=self.backend)(x, y, z) + + self.assertEqual(ref, res) + + if self.backend == "inductor": + logs = "\n".join(r.getMessage() for r in log.records) + f = FileCheck() + f.check("AFTER POST GRAD") + f.check("subgraph0").check("subgraph1") + for _ in range(2): + f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4") + f.run(logs) + + +class TestInvokeQuantEager(TestInvokeQuant): + backend = "eager" + + +class TestInvokeQuantAotEager(TestInvokeQuant): + backend = "aot_eager" + + +class TestInvokeQuantInductor(TestInvokeQuant): + backend = "inductor" + + def test_pattern_matching(self): + counter = 0 + + test_pass = PatternMatcherPass() + + def my_pass(g): + return test_pass.apply(g) + + def gn(x, y): + return torch.mul(x, y) + y + + def fn(x, y, z): + return invoke_quant_tracer(gn, (x, y), scheme="nf4") @ z + 
+ def fn_no_match(x, y, z): + return invoke_quant_tracer(gn, (x, y)) @ z + + x = torch.randn(64, 64, requires_grad=False) + y = torch.randn(64, 64, requires_grad=False) + z = torch.randn(64, 64, requires_grad=False) + + @register_graph_pattern( + CallFunction( + torch.ops.aten.mm, + CallFunction( + torch.ops.higher_order.invoke_quant, + Ignored(), + Ignored(), + Ignored(), + scheme="nf4", + ), + Arg(), + ), + pass_dict=test_pass, + ) + def quant_matching(match: Match, *args, **kwargs): + nonlocal counter + counter += 1 + + with unittest.mock.patch( + "torch._inductor.config.post_grad_custom_pre_pass", my_pass + ): + torch.compile(fn)(x, y, z) + self.assertTrue(counter == 1) + + torch.compile(fn_no_match)(x, y, z) + self.assertTrue(counter == 1) + + +del TestInvokeQuant + +if __name__ == "__main__": + run_tests() diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 7b459ffcbb9..7245e336532 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3209,6 +3209,7 @@ LEGACY_MOD_INLINELIST = { "torch._higher_order_ops.while_loop", "torch._higher_order_ops.associative_scan", "torch._higher_order_ops.scan", + "torch._higher_order_ops._invoke_quant", "torch._higher_order_ops.utils", "torch.nn.attention.flex_attention", "torch.ao.quantization.pt2e.export_utils", diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 47dbbd941f2..c8a9da5e78d 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -1,3 +1,8 @@ +from torch._higher_order_ops._invoke_quant import ( + invoke_quant, + invoke_quant_packed, + InvokeQuant, +) from torch._higher_order_ops.aoti_call_delegate import aoti_call_delegate from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.auto_functionalize import ( @@ -51,6 +56,9 @@ __all__ = [ "executorch_call_delegate", "call_torchbind", "run_const_graph", + "InvokeQuant", + "invoke_quant", + 
"invoke_quant_packed", "wrap_with_set_grad_enabled", "wrap_with_autocast", "wrap_activation_checkpoint", diff --git a/torch/_higher_order_ops/_invoke_quant.py b/torch/_higher_order_ops/_invoke_quant.py new file mode 100644 index 00000000000..cfbb7b1cc55 --- /dev/null +++ b/torch/_higher_order_ops/_invoke_quant.py @@ -0,0 +1,72 @@ +# mypy: allow-untyped-defs +# need to fix prim_hop_base type annotations first + +import dataclasses +from typing import Optional + +import torch +from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase + + +class InvokeQuantTracer(PrimHOPBase): + def __init__(self) -> None: + super().__init__("invoke_quant_packed") + + def __call__(self, subgraph, operands, *, scheme=None, quant_options=None): + subgraph = FunctionWithNoFreeVars(subgraph) + return super().__call__( + subgraph, operands, scheme=scheme, quant_options=quant_options + ) + + +invoke_quant_packed = InvokeQuantTracer() + + +class InvokeQuantUnpacked(PrimHOPBase): + def __init__(self) -> None: + super().__init__("invoke_quant") + + def __call__(self, subgraph, *operands, scheme=None): + return super().__call__(subgraph, operands, scheme=scheme) + + def _call_FakeTensorMode( + self, mode, subgraph, operands, scheme: Optional[str] = None, **kwargs + ): + # TODO: this should probably route through FakeTensorMode to reuse caching + with mode: + return subgraph(*operands[0], **kwargs) + + +invoke_quant = InvokeQuantUnpacked() + + +@dataclasses.dataclass(frozen=True, repr=True) +class InvokeQuant: + """ + Invoke a quantization function that will be preserved as a single operator. Preservation + as a single operator aids in pattern matching and custom lowerings. + + The operation appears as: + torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=scheme) + + Args: + codegen_low_precision: Use observed subgraph dtypes for codegen instead of + upcasting to fp32. Can improve performance for prologue fusion but + requires careful testing of numerics. 
+ """ + + codegen_low_precision: bool = True + + def __call__( + self, + *args, + scheme: Optional[str] = None, + **kwargs, + ): + if not torch._utils.is_compiling(): + return args[0](*args[1], **kwargs) + + if scheme is not None: + kwargs["scheme"] = scheme + + return invoke_quant_packed(*args, **kwargs, quant_options=self) # type: ignore[call-arg] diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 91998631a49..92e5a0c189f 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -2,6 +2,7 @@ import functools import itertools import logging +import operator import typing from collections import Counter from typing import Any, Union @@ -442,6 +443,81 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule): remove_redundant_views(gm) +def canonicalize_quant_mapping(gm: torch.fx.GraphModule): + """ + + + torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1)); + -> + torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); + """ + graph = gm.graph + invoke_quant_invocations = graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_quant_packed + ) + for invoke_quant in invoke_quant_invocations: + kwargs = dict(invoke_quant.kwargs) + + quant_options_node = kwargs.pop("quant_options", None) + if quant_options_node is not None: + assert isinstance(quant_options_node, torch.fx.Node) + quant_options = torch._higher_order_ops.InvokeQuant( + *invoke_quant.kwargs["quant_options"].args, + **invoke_quant.kwargs["quant_options"].kwargs, + ) + else: + quant_options = None + + subgraph, args = invoke_quant.args + with gm.graph.inserting_before(invoke_quant): + invoke_quant_replacement = graph.call_function( + torch._higher_order_ops.invoke_quant, + (subgraph, *args), + kwargs, + ) + invoke_quant_replacement.meta.update(subgraph.meta) + invoke_quant_replacement.meta["quant_options"] = 
quant_options + + invoke_quant.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(invoke_quant) + + if quant_options_node and len(quant_options_node.users) == 0: + graph.erase_node(quant_options_node) + + first_user = next(iter(invoke_quant_replacement.users)) + + if ( + len(invoke_quant_replacement.users) == 1 + and len(subgraph.users) == 1 + and first_user.target == operator.getitem + and first_user.args[1] == 0 + ): + subgraph_graph = getattr(gm, subgraph.target) + output_node = torch._inductor.utils.output_node(subgraph_graph) + assert ( + isinstance(output_node.args[0], (list, tuple)) + and len(output_node.args[0]) == 1 + ) + + unpacked_output = output_node.args[0][0] + output_node.args = (unpacked_output,) + if "val" in output_node.meta: + output_node.meta["val"] = output_node.meta["val"][0] + subgraph_graph.recompile() + + invoke_quant_replacement.meta.update(first_user.meta) + first_user.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(first_user) + + +def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule): + """ + Canonicalization passes that will run immediately after aot autograd + tracing. Thsis must be run before all other graph passes. + """ + canonicalize_quant_mapping(gm) + + def joint_graph_passes(graph: torch.fx.GraphModule): """ Run FX transformations on the joint forwards+backwards graph. 
@@ -454,6 +530,15 @@ def joint_graph_passes(graph: torch.fx.GraphModule): lazy_init() count = 0 + # must occur before other passes + canonicalize_aten_ir_passes(graph) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + from .post_grad import remove_noop_ops GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 719a49312f4..3383a4e0773 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -6754,6 +6754,23 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands): return list(map(TensorBox.create, result)) +@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None) +def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None): + output = None + for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): + if node.op == "placeholder": + V.graph.env[node] = operands[i] + continue + # todo getattr + elif node.op == "output": + args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + output = torch.fx.Interpreter.output(V.graph, node, args, kwargs) + else: + V.graph.env[node] = V.graph.run_node(node) + + return output + + @register_lowering(associative_scan_op, type_promotion_kind=None) def associative_scan(combine_fn: ir.Subgraph, xs): from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 9435f136183..bd326614fa1 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -11,6 +11,8 @@ from torch.testing._internal.common_device_type import onlyCUDA from torch.testing._internal.common_dtype import all_types_and, custom_types from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput from 
torch._higher_order_ops.invoke_subgraph import mark_compile_region +from torch._higher_order_ops import InvokeQuant, invoke_quant_packed + def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( @@ -218,6 +220,24 @@ def simple_scan(init, xs): return torch._higher_order_ops.scan(combine_fn, init, xs) +quant_tracer = InvokeQuant() + + +def simple_invoke_quant(x): + def fn(x, y): + return (torch.sin(x) * y,) + + return quant_tracer(fn, (x, x))[0] * 2. + + +def simple_invoke_quant_packed(x): + def fn(x): + return (torch.sin(x),) + + return invoke_quant_packed(fn, (x,))[0] * 2. + + + hop_db = [ OpInfo( name="scan", @@ -300,6 +320,45 @@ hop_db = [ # "torch.compile with aot_autograd does not currently support double backward." supports_gradgrad=False, ), + OpInfo( + name="invoke_quant", + variant_test_name="simple", + op=simple_invoke_quant, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + # "torch.compile with aot_autograd does not currently support double backward." 
+ supports_gradgrad=False, + ), + OpInfo( + name="invoke_quant_packed", + variant_test_name="simple", + op=simple_invoke_quant_packed, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), OpInfo( name="while_loop", variant_test_name="simple", From ade8fee5120fc0220dc15b6c16c8d3189601fca0 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Sat, 8 Feb 2025 22:40:14 +0000 Subject: [PATCH 11/28] Use c10 version of half/bfloat16 in executorch (#144111) Summary: X-link: https://github.com/pytorch/executorch/pull/7040 Accomplished by importing relevant files from c10 into executorch/runtime/core/portable_type/c10, and then using `using` in the top-level ExecuTorch headers. This approach should keep the ExecuTorch build hermetic for embedded use cases. In the future, we should add a CI job to ensure the c10 files stay identical to the PyTorch ones. 
ghstack-source-id: 260047850 exported-using-ghexport Test Plan: builds Differential Revision: D66106969 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144111 Approved by: https://github.com/malfet --- buckbuild.bzl | 1 + c10/build.bzl | 12 ++++++++++++ c10/core/build.bzl | 18 +++++++++++++++++- c10/util/build.bzl | 12 ++++++++++++ 4 files changed, 42 insertions(+), 1 deletion(-) diff --git a/buckbuild.bzl b/buckbuild.bzl index 17153b5df77..65141ac9b5a 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -997,6 +997,7 @@ def define_buck_targets( "Config.h": ":generate_aten_config[Config.h]", }, labels = labels, + visibility = ["PUBLIC"], ) fb_xplat_cxx_library( diff --git a/c10/build.bzl b/c10/build.bzl index d4192a46852..6ecae511223 100644 --- a/c10/build.bzl +++ b/c10/build.bzl @@ -22,3 +22,15 @@ def define_targets(rules): [], ), ) + + rules.cc_library( + name = "c10_headers", + deps = [ + "//c10/core:base_headers", + "//c10/macros", + "//c10/util:base_headers", + "//c10/util:bit_cast", + "//c10/util:ssize", + ], + visibility = ["//visibility:public"], + ) diff --git a/c10/core/build.bzl b/c10/core/build.bzl index 45fc5ea3390..fe9a31a2da4 100644 --- a/c10/core/build.bzl +++ b/c10/core/build.bzl @@ -90,6 +90,22 @@ def define_targets(rules): alwayslink = True, ) + rules.cc_library( + name = "base_headers", + srcs = [], + hdrs = rules.glob( + [ + "*.h", + "impl/*.h", + ], + exclude = [ + "CPUAllocator.h", + "impl/alloc_cpu.h", + ], + ), + visibility = ["//visibility:public"], + ) + rules.filegroup( name = "headers", srcs = rules.glob( @@ -101,5 +117,5 @@ def define_targets(rules): "alignment.h", ], ), - visibility = ["//c10:__pkg__"], + visibility = ["//visibility:public"], ) diff --git a/c10/util/build.bzl b/c10/util/build.bzl index a6f95ae7516..5e1dc6fbfbf 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -80,6 +80,18 @@ def define_targets(rules): ], ) + rules.cc_library( + name = "base_headers", + hdrs = rules.glob( + ["*.h"], + exclude = [ 
+ "bit_cast.h", + "ssize.h", + ], + ), + visibility = ["//visibility:public"], + ) + rules.filegroup( name = "headers", srcs = rules.glob( From 8603a1c870356b9b62cfed410ac66d0673611c35 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:19 -0300 Subject: [PATCH 12/28] Suport generators (#141055) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141055 Approved by: https://github.com/zou3519 --- test/dynamo/test_ctx_manager.py | 3 +- test/dynamo/test_exceptions.py | 15 + test/dynamo/test_generator.py | 656 ++++++++++++++++++++++++ torch/_dynamo/codegen.py | 15 +- torch/_dynamo/config.py | 4 + torch/_dynamo/exc.py | 11 + torch/_dynamo/symbolic_convert.py | 89 +++- torch/_dynamo/trace_rules.py | 15 +- torch/_dynamo/variables/__init__.py | 2 + torch/_dynamo/variables/builtin.py | 17 +- torch/_dynamo/variables/functions.py | 226 ++++++-- torch/_dynamo/variables/misc.py | 1 + torch/_dynamo/variables/user_defined.py | 7 +- 13 files changed, 994 insertions(+), 67 deletions(-) create mode 100644 test/dynamo/test_generator.py diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 12258b956bc..bf975827cf6 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1744,10 +1744,11 @@ class GraphModule(torch.nn.Module): class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase): def setUp(self): + self._prev = torch._dynamo.config.enable_trace_contextlib torch._dynamo.config.enable_trace_contextlib = True def tearDown(self): - torch._dynamo.config.enable_trace_contextlib = False + torch._dynamo.config.enable_trace_contextlib = self._prev def test_ctx_basic0(self): @contextlib.contextmanager diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index a72b87d8f6d..94d722fc059 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -404,6 +404,21 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): 
self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) + def test_raise_GeneratorExit(self): + # GeneratorExit does not inherit from Exception + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + try: + raise GeneratorExit + except Exception: + return t.sin() + except BaseException: + return t.cos() + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.cos()) + def test_speculation_exception(self): log = SpeculationLog() log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1)) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py new file mode 100644 index 00000000000..032d6a9a0ed --- /dev/null +++ b/test/dynamo/test_generator.py @@ -0,0 +1,656 @@ +# Owner(s): ["module: dynamo"] +import itertools +import unittest +from collections import OrderedDict + +import torch +import torch._dynamo.test_case +import torch._dynamo.testing +from torch._dynamo.exc import Unsupported +from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +class GeneratorTestsBase(torch._dynamo.test_case.TestCase): + def setUp(self): + super().setUp() + self._old = torch._dynamo.config.enable_faithful_generator_behavior + torch._dynamo.config.enable_faithful_generator_behavior = True + + def tearDown(self): + super().tearDown() + torch._dynamo.config.enable_faithful_generator_behavior = self._old + + def _compile_check(self, fn, args=None, fullgraph=True): + eager = EagerAndRecordGraphs() + if args is None: + args = (torch.randn(2),) + r = torch.compile(fn, backend=eager, fullgraph=fullgraph)(*args) + self.assertGreater(len(eager.graphs), 0) + return r + + +class GeneratorTests(GeneratorTestsBase): + def test_generator_simple(self): + def whoo(): + yield 1 + yield 2 + yield 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo() + t = t + next(gen) + t = t + next(gen) + t = t + next(gen) + return t + 
+ t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 6) + + def test_infinite_generator(self): + def whoo(): + i = 0 + while True: + yield i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo() + t = t + next(gen) + t = t + next(gen) + t = t + next(gen) + return t + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 3) + + def test_infinite_generator_2(self): + def whoo(t): + i = 0 + while True: + yield t + i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(t))) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, list(zip(range(3), whoo(t)))) + + def test_infinite_generator_3(self): + def whoo(i): + while True: + yield i + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(1))), t.sin() + + t = torch.randn(2) + y, _ = fn(t) + self.assertEqual(y, list(zip(range(3), whoo(1)))) + + def test_graph_break_in_generator(self): + def whoo(): + yield 1 + torch._dynamo.graph_break() + yield 2 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + gen = whoo() + s = next(gen) + s += next(gen) + return t + s + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 3) + self.assertEqual(len(eager.graphs), 0) + + def test_graph_break_in_generator_2(self): + def whoo(x): + yield x.sin() + torch._dynamo.graph_break() + yield x.cos() + + def call_whoo(x): + gen = whoo(x) + sin = next(gen) + cos = next(gen) + return sin, cos + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + sin, cos = call_whoo(t) + return sin + cos + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin() + t.cos()) + self.assertEqual(len(eager.graphs), 1) + self.assertExpectedInline( + normalize_gm(eager.graphs[0].print_readable(False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_stack0_0_: "f32[2]", L_stack0_1_: "f32[2]"): + 
l_stack0_0_ = L_stack0_0_ + l_stack0_1_ = L_stack0_1_ + + add: "f32[2]" = l_stack0_0_ + l_stack0_1_; l_stack0_0_ = l_stack0_1_ = None + return (add,) +""", + ) + + def test_generator_as_argument(self): + # The inline tracer needs to be kept in sync if an already advanced generator + # is given to a compiled function. + def whoo(): + yield 1 + yield 2 + yield 3 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo() + next(ctx) + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_generator_as_argument_2(self): + def whoo(x): + yield x.sin() + yield x.cos() + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo(t) + next(ctx) + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_generator_as_argument_3(self): + # The inline tracer needs to be kept in sync if an already advanced generator + # is given to a compiled function. 
+ def whoo(): + yield 1 + yield 2 + yield 3 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo() + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_generator_as_argument_4(self): + def whoo(x): + yield x.sin() + yield x.cos() + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t, ctx): + return t + next(ctx) + + t = torch.randn(2) + ctx = whoo(t) + with self.assertRaisesRegex( + Unsupported, "Generator as graph argument is not supported" + ): + fn(t, ctx) + + def test_islice_chain(self): + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=True) + def fn(t): + tmp1 = [t + 1, t + 2] + tmp2 = [t + 3, t + 4] + return list(itertools.chain(tmp1, tmp2)) + + t = torch.tensor([1.0]) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3, t + 4]) + + def test_zip_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + def fn(t): + return zip(range(3), whoo(t)), t.sin() + + t = torch.randn(2) + z, _ = self._compile_check(fn, args=(t,)) + self.assertEqual(list(z), list(zip(range(3), whoo(t)))) + + @unittest.expectedFailure + def test_zip_generator_2(self): + def bar(t, i): + return t + i + + def whoo(t): + yield bar(t, 1) + yield bar(t, 2) + yield bar(t, 3) + + def fn(t): + return zip(range(3), whoo(t)) + + t = torch.randn(3) + y = self._compile_check(fn, args=(t,), fullgraph=False) + expected = list(zip(range(3), whoo(t))) + self.assertEqual(expected, list(y)) + + @unittest.expectedFailure + def test_zip_subgenerator(self): + def subgen(t): + yield t + 1 + yield t + 2 + + def whoo(t): + yield from subgen(t) + yield t + 3 + + def fn(t): + return zip(range(3), whoo(t)), t.sin() + + t = torch.randn(2) + z, _ = self._compile_check(fn, args=(t,)) + self.assertEqual(list(z), list(zip(range(3), whoo(t)))) + + def 
test_list_zip_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(t))) + + t = torch.randn(3) + y = fn(t) + expected = list(zip(range(3), whoo(t))) + self.assertEqual(expected, y) + + def test_zip_infinite_generator(self): + def whoo(t): + i = 0 + while True: + yield t + i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + return list(zip(range(3), whoo(t))) + + t = torch.randn(3) + y = fn(t) + expected = list(zip(range(3), whoo(t))) + self.assertEqual(expected, y) + + @parametrize("container", [list, tuple, dict, OrderedDict]) + def test_dict_tuple_list_generator(self, container): + if container in (dict, OrderedDict): + self.skipTest("Needs __iter__") + + def whoo(t): + yield 1, t + 1 + yield 2, t + 2 + yield 3, t + 3 + + def fn(t): + gen = whoo(t) + return container(gen) + + t = torch.randn(2) + expected = fn(t) + got = torch.compile(backend="eager", fullgraph=True)(fn)(t) + self.assertEqual(expected, got) + + def test_return_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + return gen + + t = torch.tensor([1.0]) + gen = fn(t) + self.assertEqual(list(gen), [t + 1, t + 2, t + 3]) + + def test_return_tuple_generator(self): + def whoo(t): + yield t.sin() + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + g1, g2 = whoo(t), whoo(t + 1) + return (g1, g2), t.sin() + + t = torch.randn(2) + (g1, g2), _ = fn(t) + self.assertEqual(list(g1), [t.sin(), t.cos()]) + self.assertEqual(list(g2), [(t + 1).sin(), (t + 1).cos()]) + + def test_return_advanced_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + next(gen) + return gen + + t = torch.tensor([1.0]) + gen = fn(t) + 
self.assertEqual(list(gen), [t + 2, t + 3]) + + def test_return_exhaust_generator(self): + def whoo(t): + yield t + 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + next(gen) + return gen + + t = torch.tensor([1.0]) + gen = fn(t) + with self.assertRaises(StopIteration): + next(gen) + + @unittest.expectedFailure + def test_subgenerator(self): + def subgen(t): + yield t + 1 + yield t + 2 + + def main_gen(t): + yield from subgen(t) + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = main_gen(t) + return list(gen) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3]) + + @unittest.expectedFailure + def test_return_subgenerator(self): + def subgen(t): + yield t + 1 + yield t + 2 + + def main_gen(t): + yield from subgen(t) + yield t + 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = main_gen(t) + next(gen) + return gen + + t = torch.randn(2) + gen = fn(t) + self.assertEqual(list(gen), [t + 2, t + 3]) + + def test_dynamo_disable_generator(self): + @torch._dynamo.disable + def main_gen(t): + yield t + 1 + yield t + 2 + yield t + 3 + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = main_gen(t) + return list(gen) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3]) + + def test_dynamo_disable_sub_generator(self): + @torch._dynamo.disable + def subgen(t): + yield t + 2 + yield t + 3 + + def main_gen(t): + yield t + 1 + yield from subgen(t) + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = main_gen(t) + return list(gen) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, [t + 1, t + 2, t + 3]) + + def test_graph_break_outside_generator(self): + def whoo(t): + yield t + 1 + yield t + 2 + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + x = next(gen) + torch._dynamo.graph_break() + y = next(gen) + return x + y + + t = torch.randn(2) + y = fn(t) + 
self.assertEqual(y, (t + 1) + (t + 2)) + + def test_graph_break_before_calling_generator(self): + def whoo(t): + for perm in itertools.product(itertools.permutations((0, 1, 2)), repeat=1): + yield sum(perm[0]) + + def fn(t): + s = 0 + for b, p in itertools.product(whoo(t), itertools.permutations((4, 5))): + s += b + return s + + t = torch.randn(2) + expected = fn(t) + got = torch.compile(backend="eager", fullgraph=False)(fn)(t) + self.assertEqual(expected, got) + + @unittest.expectedFailure + def test_generator_with_side_effects(self): + i = 0 + + def whoo(t): + nonlocal i + for j in range(5): + i += 1 + yield t + j + + def fn(t): + return whoo(t), t.sin() + + t = torch.randn(2) + with self.assertRaises(Unsupported): + fn(t) + + @unittest.expectedFailure + def test_subgenerator_with_side_effects(self): + i = 0 + + def subgen(t): + nonlocal i + i += 1 + yield t + i += 1 + yield t + 1 + + def whoo(t): + nonlocal i + yield from subgen(t) + i += 1 + yield t + 2 + i += 1 + yield t + 3 + i += 1 + yield t + 4 + + def fn(t): + return whoo(t), t.sin() + + with self.assertRaises(Unsupported): + self._compile_check(fn) + + @unittest.expectedFailure + def test_generator_with_side_effects_graph_break(self): + i = 0 + + def whoo(t): + nonlocal i + for j in range(5): + i += 1 + yield t + j + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + torch._dynamo.graph_break() + return list(zip(range(3), gen)) + + t = torch.randn(2) + with self.assertRaises(Unsupported): + fn(t) + + def test_generator_with_side_effects_graph_break_2(self): + i = 0 + + def whoo(t): + nonlocal i + for j in range(5): + i += 1 + yield t + j + torch._dynamo.graph_break() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + return list(zip(range(3), gen)) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(i, 3) + self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)]) + + +class GeneratorCPythonTests(GeneratorTestsBase): + # Taken from commit + # 
https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + @unittest.expectedFailure + def test_send_non_none_to_new_gen(self): + def f(): + yield 1 + + def fn(t): + g = f() + z = 0 + try: + g.send(0) + except TypeError: + z += 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + assert next(g) == 1 + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_issue103488(self): + def gen_raises(): + yield 1 + raise ValueError + + def loop(): + try: + for _ in gen_raises(): + if True is False: # noqa: PLR0133 + return + except ValueError: + pass + + def fn(t): + # This should not raise + loop() + return t.sin() + + self._compile_check(fn) + + +instantiate_parametrized_tests(GeneratorTests) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 893e7056281..2aed88f713b 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -26,7 +26,10 @@ from .exc import IncorrectUsage, unimplemented from .source import AttrSource, Source from .utils import is_safe_constant, rot_n_helper from .variables.base import ValueMutationExisting, VariableTracker -from .variables.functions import FunctionDecoratedByContextlibContextManagerVariable +from .variables.functions import ( + ContextlibContextManagerLocalGeneratorObjectVariable, + LocalGeneratorObjectVariable, +) from .variables.nn_module import NNModuleVariable from .variables.tensor import ( NumpyNdarrayVariable, @@ -162,14 +165,20 @@ class PyCodegen: return if value.is_realized() and isinstance( - value, FunctionDecoratedByContextlibContextManagerVariable + value, ContextlibContextManagerLocalGeneratorObjectVariable ): raise IncorrectUsage( "NYI: Returning a @contextmanager object from a torch.compile 
function" ) # Dynamo normally prefers codegen from source to account for aliasing. - if value.source is not None and allow_cache: + if ( + value.source is not None + and allow_cache + and not ( + value.is_realized() and isinstance(value, LocalGeneratorObjectVariable) + ) + ): # There's a corner case for export: for instance, if the computation # graph is just identity on an input tensor, Dynamo would just emit # a `LOAD_FAST` from the input source, rather than generating an diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index dcc4de1e7f0..c3b57afc881 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -417,6 +417,10 @@ enable_cpp_symbolic_shape_guards = False # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True +# Enable tracing generator functions lazily. If False, Dynamo will exhaust +# generators upon first execution. And if True, the generator will be accessed lazily +enable_faithful_generator_behavior = False + # Inline inbuilt nn modules inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated] default=True, diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index eb62d4d30db..74e714692a0 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -162,6 +162,12 @@ class AttributeMutationError(Unsupported): super().__init__(msg) +class InfiniteGeneratorError(Unsupported): + # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT + def __init__(self, msg: str) -> None: + super().__init__(msg) + + class CondOpArgsMismatchError(ArgsMismatchError): """ Internal error from cond() due to arguments mismatch. @@ -267,6 +273,10 @@ class ObservedKeyError(ObservedLookupError): pass +class ObservedGeneratorExit(ObservedException): + pass + + class ObservedAttributeError(ObservedException): # An AttributeError exception to be raised from inside Dynamo tracing. 
This can happen on user defined object __getattr__ pass @@ -284,6 +294,7 @@ observed_exception_map = { StopIteration: ObservedUserStopIteration, LookupError: ObservedLookupError, IndexError: ObservedIndexError, + GeneratorExit: ObservedGeneratorExit, KeyError: ObservedKeyError, AttributeError: ObservedAttributeError, RuntimeError: ObservedRuntimeError, diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 8e2b1bfa61c..33b183da576 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -85,7 +85,8 @@ from .variables.ctx_manager import ( from .variables.dicts import ConstDictVariable, SetVariable from .variables.functions import ( BaseUserFunctionVariable, - FunctionDecoratedByContextlibContextManagerVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, NestedUserFunctionVariable, SkipFunctionVariable, UserFunctionVariable, @@ -922,11 +923,22 @@ class InstructionTranslatorBase( raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] + def inline_generator_function(self, fn, args, kwargs): + """ + Redirect the call to the generator "call_function" + """ + if not isinstance(fn, LocalGeneratorFunctionVariable): + fn = LocalGeneratorFunctionVariable(fn) + return fn.call_function(self, args, kwargs) + def inline_user_function_return(self, fn, args, kwargs): """ A call to some user defined function by inlining it. 
""" - return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): + return self.inline_generator_function(fn, args, kwargs) + else: + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) def get_line_of_code_header(self, lineno=None): if lineno is None: @@ -3083,7 +3095,20 @@ class InstructionTranslator(InstructionTranslatorBase): return True return False + def replace_tos_if_return_is_generator(self): + if ( + len(self.stack) + and (tos := self.stack[-1]) + and isinstance(tos, LocalGeneratorObjectVariable) + ): + self.stack[-1] = ListIteratorVariable( + tos.force_unpack_var_sequence(self), + mutation_type=ValueMutationNew(), + ) + def _return(self, inst): + self.replace_tos_if_return_is_generator() + if ( not config.allow_empty_graphs and self.output.count_calls() == 0 @@ -3093,6 +3118,7 @@ class InstructionTranslator(InstructionTranslatorBase): and not self.one_graph ): raise exc.SkipFrame("because no content in function call") + self.instruction_pointer = None _step_logger()( logging.INFO, @@ -3179,8 +3205,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase): func: VariableTracker, args: list[VariableTracker], kwargs, - *, - stop_generator_on_yield: bool = False, ): if isinstance(func, SkipFunctionVariable): unimplemented("inline with functions in skip files") @@ -3189,7 +3213,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase): ( UserFunctionVariable, NestedUserFunctionVariable, - FunctionDecoratedByContextlibContextManagerVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, ), ) result = InliningInstructionTranslator.check_inlineable(func) @@ -3254,9 +3279,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase): parent.symbolic_globals, parent.symbolic_torch_function_state, func, - stop_generator_on_yield=stop_generator_on_yield, ) else: + # need the line below to make MyPy happy + 
assert not isinstance(func, LocalGeneratorObjectVariable) tracer = InliningInstructionTranslator( parent, code, @@ -3302,24 +3328,32 @@ class InliningInstructionTranslator(InstructionTranslatorBase): log.debug("DONE INLINING %s", code) - if is_generator(code): - assert isinstance(self, InliningGeneratorInstructionTranslator) - # The first flag tells us if we consume generators lazily or not - # and the second is if the generator is exhausted. - # In the future, generators should be lazily consumed and the first - # flag (stop_generator_on_yield) will not be needed. - if self.stop_generator_on_yield and self.generator_exhausted: + if config.enable_faithful_generator_behavior or ( + isinstance(self, InliningGeneratorInstructionTranslator) + and self.is_generator_from_ctx_manager + ): + if ( + is_generator(code) + and isinstance(self, InliningGeneratorInstructionTranslator) + and self.generator_exhausted + ): + assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration r = self.symbolic_result assert r.as_python_constant() is None exc.raise_observed_exception(StopIteration, self) else: + return self.symbolic_result + else: + if is_generator(code): + assert isinstance(self, InliningGeneratorInstructionTranslator) + assert self.symbolic_result.as_python_constant() is None return ListIteratorVariable( self.generated_items, mutation_type=ValueMutationNew(), ) - else: - return self.symbolic_result + else: + return self.symbolic_result def __init__( self, @@ -3438,27 +3472,26 @@ class InliningInstructionTranslator(InstructionTranslatorBase): class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): generated_items: list[VariableTracker] # Flag wether or not the InlineGenerator should consume the entire iterator - stop_generator_on_yield: bool - def __init__(self, *args, stop_generator_on_yield: bool = False, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: 
super().__init__(*args, **kwargs) self.generated_items = [] - # In the future, generators should run lazily (i.e. when next(...) is called) - # TODO: Set this to True by default, so that dynamo follows CPython more - # closely - self.stop_generator_on_yield = stop_generator_on_yield self.generator_exhausted = False + self.is_generator_from_ctx_manager = False def YIELD_VALUE(self, inst: Instruction): top = self.pop() self.generated_items.append(top) if len(self.generated_items) > MAX_ITERATOR_LIMIT: - unimplemented( + raise exc.InfiniteGeneratorError( "Too many yield values in generator. Maybe you are inlining an infinite generator. " f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}", ) self.push(ConstantVariable.create(None)) - if self.stop_generator_on_yield: + if ( + config.enable_faithful_generator_behavior + or self.is_generator_from_ctx_manager + ): self.symbolic_result = top # Stop tracing raise YieldValueOp @@ -3500,10 +3533,6 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): self.pop() self.push(ConstantVariable.create(ex.value)) else: - self.push(val) - # Add the value to yield into generated_items and replace the top of the stack with None - self.YIELD_VALUE(inst) - # Repeat the YIELD_FROM instruction in the next eval loop assert ( isinstance(self.instruction_pointer, int) @@ -3511,11 +3540,15 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): ) self.instruction_pointer -= 1 + self.push(val) + # Add the value to yield into generated_items and replace the top of the stack with None + self.YIELD_VALUE(inst) + def SEND(self, inst): assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] - if isinstance(tos, ListIteratorVariable) or ( + if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or ( isinstance(tos, UserDefinedObjectVariable) and isinstance(tos.value, collections.abc.Iterator) ): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py 
index 7245e336532..d06ffccfc4c 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -32,8 +32,9 @@ from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper from .variables import ( BuiltinVariable, FunctionalCallVariable, - FunctionDecoratedByContextlibContextManagerVariable, FunctorchHigherOrderVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, @@ -3620,7 +3621,8 @@ def check_verbose(obj, is_inlined_call=False): UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable, - FunctionDecoratedByContextlibContextManagerVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, ), ): try: @@ -3640,7 +3642,14 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: set[str] = set() rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) - if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): + if issubclass( + rule, + ( + UserFunctionVariable, + LocalGeneratorFunctionVariable, + PolyfilledFunctionVariable, + ), + ): return SkipResult( False, f"inlined according trace_rules.lookup {reasons.pop()}", diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 9fc28fe50a6..ba7a10267e2 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -39,6 +39,8 @@ from .functions import ( FunctionDecoratedByContextlibContextManagerVariable, FunctoolsPartialVariable, FunctoolsWrapsVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 6e411dfe50d..709c76f74d7 100644 --- a/torch/_dynamo/variables/builtin.py +++ 
b/torch/_dynamo/variables/builtin.py @@ -773,7 +773,13 @@ class BuiltinVariable(VariableTracker): tx, [v.realize() for v in args], kwargs ) - if inspect.isclass(fn) and issubclass(fn, Exception): + if inspect.isclass(fn) and ( + issubclass(fn, Exception) + # GeneratorExit doens't inherit from Exception + # >>> issubclass(GeneratorExit, Exception) + # False + or fn is GeneratorExit + ): def create_exception_class_object( tx: "InstructionTranslator", args, kwargs @@ -1425,6 +1431,13 @@ class BuiltinVariable(VariableTracker): mutation_type=ValueMutationNew(), ) + def _call_iter_tuple_generator(self, tx, obj, *args, **kwargs): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), # exhaust generator + mutation_type=ValueMutationNew(), + ) + def _call_tuple_list(self, tx, obj=None, *args, **kwargs): if isinstance(obj, variables.IteratorVariable): cls = variables.BaseListVariable.cls_for(self.fn) @@ -1432,6 +1445,8 @@ class BuiltinVariable(VariableTracker): list(obj.force_unpack_var_sequence(tx)), mutation_type=ValueMutationNew(), ) + elif isinstance(obj, variables.LocalGeneratorObjectVariable): + return self._call_iter_tuple_generator(tx, obj, *args, **kwargs) else: return self._call_iter_tuple_list(tx, obj, *args, **kwargs) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 29a3cb18abd..785c075f380 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -6,15 +6,24 @@ import inspect import itertools import types from collections.abc import Sequence -from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar from typing_extensions import Never from unittest.mock import patch import torch from .. 
import polyfills, variables -from ..bytecode_transformation import create_call_function, create_rot_n -from ..exc import raise_observed_exception, unimplemented, Unsupported +from ..bytecode_transformation import create_call_function, create_rot_n, is_generator +from ..exc import ( + handle_observed_exception, + InfiniteGeneratorError, + ObservedException, + ObservedUserStopIteration, + raise_observed_exception, + SkipFrame, + unimplemented, + Unsupported, +) from ..guards import GuardBuilder, install_guard from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource from ..utils import ( @@ -378,61 +387,218 @@ class BuiltinMethodVariable(BaseUserFunctionVariable): return obj_vt.call_method(tx, name, args, kwargs) -class FunctionDecoratedByContextlibContextManagerVariable(BaseUserFunctionVariable): - # TODO(guilherme): replace this with a generic GeneratorFunctionVariable +class LocalGeneratorObjectVariable(VariableTracker): + def __init__( + self, + code: types.CodeType, + f_globals, + inline_tracer: Optional["InstructionTranslator"], + **kwargs, + ): + super().__init__(**kwargs) + self.code = code + self.f_globals = f_globals + self.inline_tracer = inline_tracer + def get_code(self): + return self.code + + def get_filename(self): + return self.get_code().co_filename + + def get_name(self): + return self.get_code().co_name + + def get_function(self): + raise NotImplementedError + + def has_self(self): + return False + + def __name__(self): + return self.get_name() + + def __str__(self): + return f"{self.__class__.__name__}({self.get_name()})" + + __repr__ = __str__ + + def reconstruct(self, codegen): + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + tracer = self._get_inline_tracer(tx) + try: + prev = tx.output.should_exit + tx.output.should_exit = False + if not tracer.generator_exhausted: + self.remaining_items = self.force_unpack_var_sequence(tx) + 
variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) + finally: + tx.output.should_exit = prev + + def bind_args(self, tx, args, kwargs): + return self.fn.bind_args(tx, args, kwargs) + + def get_globals(self): + return self.f_globals + + def python_type(self): + return types.GeneratorType + + def _get_inline_tracer(self, tx): + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + if self.inline_tracer is None: + self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( + tx, self, [], {} + ) + return self.inline_tracer + + def next_variable(self, tx): + tracer = self._get_inline_tracer(tx) + + try: + # Hierarchically, tx can be seen as the parent of the inline tracer + # created on call_function. Any exception needs to be propagated to tx + # for Dynamo to behave correctly + with patch.dict(counters, {"unimplemented": counters["inline_call"]}): + return tracer.inline_call_() + except ObservedException as e: + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + raise e + except InfiniteGeneratorError: + # test/dynamo/test_misc.py::test_iterator_limit + raise + except Unsupported as e: + torch._C._dynamo.eval_frame.skip_code(self.get_code()) + raise SkipFrame from e + + def has_unpack_var_sequence(self, tx): + return False + + def has_force_unpack_var_sequence(self, tx) -> builtins.bool: + return True + + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + result = [] + while True: + try: + result.append(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__next__": + return self.next_variable(tx) + + super().call_method(tx, name, args, kwargs) + + +class ContextlibContextManagerLocalGeneratorObjectVariable( + LocalGeneratorObjectVariable +): + """ + .. 
note:: + + This is only used when the function is annotated with @contextlib.contextmanager + + It is a special case of a generator function as we do not allow return a context manager + from a torch.compile function. + """ + + +class LocalGeneratorFunctionVariable(BaseUserFunctionVariable): """functions that behaves like iterators .. note:: - This is only used when the function is annotated with @contextlib.contextmanager + This is a wrapper around (Nested)UserFunctionVariable """ - def __init__(self, vt: VariableTracker, **kwargs): + def __init__( + self, + vt: VariableTracker, + *, + generator_cls=LocalGeneratorObjectVariable, + **kwargs, + ): super().__init__(**kwargs) self.vt = vt - self.inline_tracer = None + self.generator_cls = generator_cls def __getattr__(self, name): if name in self.__class__.__dict__.keys(): return getattr(self, name) return getattr(self.vt, name) + def _build_inline_tracer(self, tx, args, kwargs): + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + return InliningInstructionTranslator.build_inline_tracer( + tx, + self, + args, + kwargs, + ) + def call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - from torch._dynamo.bytecode_transformation import is_generator + assert is_generator(self.vt.get_code()) - assert is_generator(self.get_code()) - from torch._dynamo.symbolic_convert import InliningInstructionTranslator + inline_tracer = self._build_inline_tracer(tx, args, kwargs) + code = self.vt.get_code() + f_globals = self.vt.get_globals() - self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( - tx, - self, - [*self.self_args(), *args], - kwargs, - stop_generator_on_yield=True, + # calling a generator returns a generator object + return self.generator_cls( + code, + f_globals, + inline_tracer, + source=self.source, ) - return self - def next_variable(self, tx): - from torch._dynamo import exc +class 
FunctionDecoratedByContextlibContextManagerVariable( + LocalGeneratorFunctionVariable +): + """ + .. note:: - tracer = self.inline_tracer + This is only used when the function is annotated with @contextlib.contextmanager + """ - try: - # Hierarchically, tx can be seen as the parent of the inline tracer - # created on call_function. Any exception needs to be propagated to tx - # for Dynamo to behave correctly - with patch.dict(counters, {"unimplemented": counters["inline_call"]}): - return tracer.inline_call_().next_variable(tx) - except exc.ObservedException as e: - tx.exn_vt_stack.extend(tracer.exn_vt_stack) - raise e + def __init__(self, vt, **kwargs): + super().__init__( + vt, + generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, + **kwargs, + ) + + def _build_inline_tracer(self, tx, args, kwargs): + # NOTE: This only exists to not break support for context manager when + # config.enable_faithful_generator_behavior = False and + # config.enable_trace_contextlib = True. 
In case the former is false, + # Dynamo should still be able to trace through @contextmanager functions + tracer = super()._build_inline_tracer(tx, args, kwargs) + assert isinstance( + tracer, + torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator, + ) + tracer.is_generator_from_ctx_manager = True + return tracer class UserMethodVariable(UserFunctionVariable): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 3c55c1b2afc..bcadb5941ff 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -234,6 +234,7 @@ class SuperVariable(VariableTracker): class ExceptionVariable(VariableTracker): + # The ExceptionVariable corresponds to the BaseException class in Python def __init__(self, exc_type, args, **kwargs) -> None: super().__init__(**kwargs) self.exc_type = exc_type diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index a9dd87d3046..fe646d35b05 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -493,7 +493,7 @@ class UserDefinedClassVariable(UserDefinedVariable): ): if not torch._dynamo.config.enable_trace_contextlib: unimplemented("contextlib.contextmanager") - # Replace UserFunctionVariable by FunctionDecoratedBycontextlibContextManagerVariable + # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable # if the function is annotated with @contextlib.contextmanager # This shouldn't be necessary once generator functions are fully # supported in dynamo @@ -805,6 +805,11 @@ class UserDefinedObjectVariable(UserDefinedVariable): # of the cmp_eq polyfill function. 
return ConstantVariable.create(self.value is other.value) + if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( + self.value, types.GeneratorType + ): + unimplemented("Generator as graph argument is not supported") + # check for methods implemented in C++ if isinstance(method, types.FunctionType): source = ( From d798831167442a586c7236aa6ec6d12f9a0360b9 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:19 -0300 Subject: [PATCH 13/28] Implement `generator.__iter__()` (#144421) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144421 Approved by: https://github.com/zou3519 ghstack dependencies: #141055 --- test/dynamo/test_generator.py | 27 ++++++++++++++++++++------- torch/_dynamo/variables/functions.py | 3 +++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 032d6a9a0ed..c91d81398a5 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -279,7 +279,6 @@ class GraphModule(torch.nn.Module): expected = list(zip(range(3), whoo(t))) self.assertEqual(expected, list(y)) - @unittest.expectedFailure def test_zip_subgenerator(self): def subgen(t): yield t + 1 @@ -329,9 +328,6 @@ class GraphModule(torch.nn.Module): @parametrize("container", [list, tuple, dict, OrderedDict]) def test_dict_tuple_list_generator(self, container): - if container in (dict, OrderedDict): - self.skipTest("Needs __iter__") - def whoo(t): yield 1, t + 1 yield 2, t + 2 @@ -407,7 +403,6 @@ class GraphModule(torch.nn.Module): with self.assertRaises(StopIteration): next(gen) - @unittest.expectedFailure def test_subgenerator(self): def subgen(t): yield t + 1 @@ -426,7 +421,6 @@ class GraphModule(torch.nn.Module): y = fn(t) self.assertEqual(y, [t + 1, t + 2, t + 3]) - @unittest.expectedFailure def test_return_subgenerator(self): def subgen(t): yield t + 1 @@ -598,6 +592,26 @@ class GraphModule(torch.nn.Module): self.assertEqual(i, 3) 
self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)]) + def test_iter(self): + def whoo(): + i = 0 + while True: + yield i + i += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + s = 0 + for i in whoo(): + if i > 5: + break + s += i + return t + s + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + sum(range(6))) + class GeneratorCPythonTests(GeneratorTestsBase): # Taken from commit @@ -625,7 +639,6 @@ class GeneratorCPythonTests(GeneratorTestsBase): self._compile_check(fn) - @unittest.expectedFailure def test_issue103488(self): def gen_raises(): yield 1 diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 785c075f380..bf5140a3a5b 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -499,6 +499,9 @@ class LocalGeneratorObjectVariable(VariableTracker): ) -> "VariableTracker": if name == "__next__": return self.next_variable(tx) + elif name == "__iter__": + # iter(gen) returns itself + return self super().call_method(tx, name, args, kwargs) From ca9b16e070003cc650db8c3977e71da2cbf8065c Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:20 -0300 Subject: [PATCH 14/28] Implement `generator.send(..)` (#144422) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144422 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421 --- test/dynamo/test_generator.py | 41 +++++++++++++++++++++++++++- torch/_dynamo/exc.py | 6 ++++ torch/_dynamo/variables/functions.py | 18 ++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index c91d81398a5..99388757e06 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -613,13 +613,51 @@ class GraphModule(torch.nn.Module): self.assertEqual(y, t + sum(range(6))) +class TestGeneratorSend(GeneratorTestsBase): + def test_send(self): + def double(): + x = yield + yield 
x * 2 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = double() + next(gen) + return gen.send(t) + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t * 2) + + @parametrize("fullgraph", [True, False]) + def test_send_stop_iteration(self, fullgraph): + def double(): + x = yield + yield x * 2 + + @torch.compile(backend="eager", fullgraph=fullgraph) + def fn(t): + gen = double() + next(gen) + a = gen.send(t) + b = gen.send(t) # should result in StopIteration + return a + b + + t = torch.randn(2) + if fullgraph: + with self.assertRaisesRegex(Unsupported, "Observed exception"): + fn(t) + else: + with self.assertRaises(StopIteration): + fn(t) + + class GeneratorCPythonTests(GeneratorTestsBase): # Taken from commit # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 # changed the tests a little bit to run them inside dynamo # + replaced all self.assert* calls to plain assert statements - @unittest.expectedFailure def test_send_non_none_to_new_gen(self): def f(): yield 1 @@ -661,6 +699,7 @@ class GeneratorCPythonTests(GeneratorTestsBase): instantiate_parametrized_tests(GeneratorTests) +instantiate_parametrized_tests(TestGeneratorSend) if __name__ == "__main__": diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 74e714692a0..516889821eb 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -290,6 +290,11 @@ class ObservedNotImplementedError(ObservedException): pass +class ObservedTypeError(ObservedException): + # A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) 
method + pass + + observed_exception_map = { StopIteration: ObservedUserStopIteration, LookupError: ObservedLookupError, @@ -299,6 +304,7 @@ observed_exception_map = { AttributeError: ObservedAttributeError, RuntimeError: ObservedRuntimeError, NotImplementedError: ObservedNotImplementedError, + TypeError: ObservedTypeError, } diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index bf5140a3a5b..d64ae10e3a4 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -490,6 +490,9 @@ class LocalGeneratorObjectVariable(VariableTracker): break return result + def _is_generator_just_started(self): + return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 + def call_method( self, tx: "InstructionTranslator", @@ -502,6 +505,21 @@ class LocalGeneratorObjectVariable(VariableTracker): elif name == "__iter__": # iter(gen) returns itself return self + elif name == "send": + # Sends a value into the generator function. 
Returns the next value + # yielded by the generator, or raises StopIteration if the generator + # exits without yielding another value + if self._is_generator_just_started() and len(args): + # can't send non-None value to a just-started generator + # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen + if not all( + isinstance(arg, ConstantVariable) and arg.value is None + for arg in args + ): + raise_observed_exception(TypeError, tx) + tracer = self._get_inline_tracer(tx) + tracer.push_many(args) + return self.next_variable(tx) super().call_method(tx, name, args, kwargs) From 8ee095f7c1ec976370346fbf4c99e3cda1d2f648 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:20 -0300 Subject: [PATCH 15/28] Implement `generator.close()` (#144423) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144423 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422 --- test/dynamo/test_generator.py | 467 +++++++++++++++++++++++++++ torch/_dynamo/exc.py | 1 + torch/_dynamo/symbolic_convert.py | 2 - torch/_dynamo/variables/functions.py | 84 +++++ 4 files changed, 552 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 99388757e06..8aa86eb4bb3 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import itertools +import sys import unittest from collections import OrderedDict @@ -652,6 +653,471 @@ class TestGeneratorSend(GeneratorTestsBase): fn(t) +class TestGeneratorClose(GeneratorTestsBase): + def test_close(self): + def whoo(t): + yield t.sin() + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + + @unittest.expectedFailure + @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") + def test_close_subgen(self): + z = 0 + + 
def subgen(t): + nonlocal z + z = 1 + yield t.sin() + z = 3 + yield t.cos() + + def whoo(t): + yield from subgen(t) + yield t.tan() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(z, 1) + + def test_close_with_side_effects(self): + L = [] + z = 0 + + def whoo(t): + nonlocal z + try: + L.append(1) + yield t.sin() + L.append(2) + yield t.cos() + finally: + L.append(z) + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + z = -123 + gen.close() + L.append(len(L)) + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(L, [1, -123, 2]) + + def test_close_capture_GeneratorExit_return(self): + z = 0 + + def whoo(t): + nonlocal z + try: + z += 1 + yield t.sin() + yield t.cos() + except GeneratorExit: + z += 10 + return t.tan() # noqa: B901 + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + y = gen.close() + return (i, y) + + t = torch.randn(2) + (i, y) = fn(t) + self.assertEqual(i, t.sin()) + self.assertEqual(y, t.tan()) + self.assertEqual(z, 111) + + @parametrize("fullgraph", [True, False]) + def test_close_capture_GeneratorExit(self, fullgraph): + z = 0 + + def whoo(t): + nonlocal z + try: + yield t.sin() + yield t.cos() + except GeneratorExit: + yield t.tan() + finally: + z = 1 + + @torch.compile(backend="eager", fullgraph=fullgraph) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + if fullgraph: + # This should actually be RuntimeError("generator ignored GeneratorExit") + # but Dynamo swallow the exception and raises Unsupported instead + with self.assertRaisesRegex(Unsupported, "Observed exception"): + fn(t) + else: + with self.assertRaisesRegex( + RuntimeError, "generator 
ignored GeneratorExit" + ): + fn(t) + + def test_close_capture_and_reraise_GeneratorExit(self): + L = [] + z = 0 + + def whoo(t): + nonlocal z + try: + L.append(1) + yield t.sin() + yield t.cos() + except GeneratorExit: + L.append(z) + z = -1 + raise + finally: + L.append(z) + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + z = -123 + gen.close() + L.append(456) + return i + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(L, [1, -123, -1, 456]) + + @parametrize("exc", [RuntimeError, AttributeError]) + def test_close_capture_and_reraise_exc(self, exc): + def whoo(t): + try: + yield t.sin() + yield t.cos() + except GeneratorExit as e: + raise exc from e + finally: + pass + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return i + + t = torch.randn(2) + with self.assertRaises(exc): + fn(t) + + @unittest.expectedFailure + @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") + def test_close_with_subgen(self): + L = [] + z = 0 + + def subgen(t): + yield t.sin() + yield t.cos() + + def whoo(t): + nonlocal z + L.append(10) + yield from subgen(t) + L.append(20) + try: + L.append(1) + z = 4 + yield t.tan() + finally: + L.append(z) + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + nonlocal z + gen = whoo(t) + i = next(gen) + z = -123 + gen.close() + L.append(456) + return i, t.sin() + + t = torch.randn(2) + y, _ = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(L, [10, 456]) + self.assertEqual(z, -123) + + def test_close_after_close(self): + z = 0 + + def whoo(t): + nonlocal z + try: + z += 1 + yield t.sin() + yield t.cos() + finally: + # finally should only be executed once + z += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + i = next(gen) + gen.close() + return (i, gen.close()) + + t = torch.randn(2) + (i, y) = fn(t) + self.assertEqual(i, 
t.sin()) + self.assertEqual(y, None) + self.assertEqual(z, 2) + + @parametrize("fullgraph", [True, False]) + def test_next_after_close(self, fullgraph): + def whoo(t): + yield t.sin() + yield t.cos() + + @torch.compile(backend="eager", fullgraph=fullgraph) + def fn(t): + gen = whoo(t) + gen.close() + a = next(gen) + return [t.sin(), a] + + t = torch.randn(3) + if fullgraph: + with self.assertRaises(Unsupported): + fn(t) + else: + with self.assertRaises(StopIteration): + fn(t) + + def test_close_after_exception(self): + def whoo(t): + raise ValueError("foo") + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + try: + next(gen) + except ValueError: + pass + b = gen.close() + return [t.sin(), b] + + t = torch.randn(2) + y, b = fn(t) + self.assertEqual(y, t.sin()) + self.assertIsNone(b) + + def test_close_handling_finally(self): + z = 0 + + def whoo(t): + nonlocal z + try: + yield t.sin() + yield t.cos() + except GeneratorExit: + z += 1 + return t.tan() # noqa: B901 + finally: + z += 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + next(gen) + b = gen.close() + return t.sin(), b + + t = torch.randn(2) + y, b = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(b, t.tan()) + self.assertEqual(z, 2) + + +class GeneratorCloseCPythonTests(GeneratorTestsBase): + # Taken from commit + # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + def test_close_no_return_value(self): + def f(): + yield + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_return_value(self): + def f(): + try: + yield + # close() raises GeneratorExit here, which is caught + except GeneratorExit: + return 0 # noqa: B901 + + 
@torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() == 0 + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_not_catching_exit(self): + def f(): + yield + # close() raises GeneratorExit here, which isn't caught and + # therefore propagates -- no return value + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_not_started(self): + def f(): + try: + yield + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_exhausted(self): + def f(): + try: + yield + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + next(gen) + z = 0 + try: + next(gen) # -> StopIteration + except StopIteration: + z = 1 + except Exception as e: + # anything other than StopIteration should fail + raise AssertionError from e + assert z == 1 + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_closed(self): + def f(): + try: + yield + except GeneratorExit: + return 0 # noqa: B901 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + assert gen.close() == 0 + assert gen.close() is None + return t.sin() + + t = torch.randn(2) + fn(t) + + def test_close_raises(self): + def f(): + try: + yield + except GeneratorExit: + pass + raise RuntimeError + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = f() + gen.send(None) + z = 0 + try: + gen.close() # -> RuntimeError + except RuntimeError: + z = 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + return t.sin() + + t = torch.randn(2) + fn(t) + + class 
GeneratorCPythonTests(GeneratorTestsBase): # Taken from commit # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 @@ -700,6 +1166,7 @@ class GeneratorCPythonTests(GeneratorTestsBase): instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(TestGeneratorSend) +instantiate_parametrized_tests(TestGeneratorClose) if __name__ == "__main__": diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 516889821eb..a2f5f938bdc 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -283,6 +283,7 @@ class ObservedAttributeError(ObservedException): class ObservedRuntimeError(ObservedException): + # A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on generator.throw(..) method pass diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 33b183da576..fd39e76032e 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3339,8 +3339,6 @@ class InliningInstructionTranslator(InstructionTranslatorBase): ): assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration - r = self.symbolic_result - assert r.as_python_constant() is None exc.raise_observed_exception(StopIteration, self) else: return self.symbolic_result diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index d64ae10e3a4..293ba87a516 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -4,6 +4,7 @@ import builtins import functools import inspect import itertools +import sys import types from collections.abc import Sequence from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar @@ -18,6 +19,7 @@ from ..exc import ( handle_observed_exception, InfiniteGeneratorError, ObservedException, + ObservedGeneratorExit, ObservedUserStopIteration, raise_observed_exception, SkipFrame, @@ -458,6 +460,9 @@ class 
LocalGeneratorObjectVariable(VariableTracker): def next_variable(self, tx): tracer = self._get_inline_tracer(tx) + if self._is_generator_exhausted(): + raise_observed_exception(StopIteration, tx) + try: # Hierarchically, tx can be seen as the parent of the inline tracer # created on call_function. Any exception needs to be propagated to tx @@ -490,9 +495,22 @@ class LocalGeneratorObjectVariable(VariableTracker): break return result + def _setup_exception(self, tx, exc): + tracer = self._get_inline_tracer(tx) + tracer.push(exc) + try: + tracer._raise_exception_variable(None) + except ObservedException as e: + # if no handler is available (i.e. user code doesn't catch it), the + # exception is raised again. + tracer.exception_handler(e) + def _is_generator_just_started(self): return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 + def _is_generator_exhausted(self): + return getattr(self.inline_tracer, "generator_exhausted", False) + def call_method( self, tx: "InstructionTranslator", @@ -520,6 +538,72 @@ class LocalGeneratorObjectVariable(VariableTracker): tracer = self._get_inline_tracer(tx) tracer.push_many(args) return self.next_variable(tx) + elif name == "close": + # * Raises a GeneratorExit at the point where the generator function was paused. + # * If the generator function catches the exception and returns a + # value, this value is returned from close() - Python 3.13+ + # * If the generator function is already closed, or raises GeneratorExit + # (by not catching the exception), close() returns None. + # * If the generator yields a value, a RuntimeError is raised. + # * If the generator raises any other exception, it is propagated to the caller. + # * If the generator has already exited due to an exception or normal + # exit, close() returns None and has no other effect. 
+ + # Return None if close is called on a just-started generator + # See test GeneratorCloseCpythonTests::test_close_not_started + + tracer = self._get_inline_tracer(tx) + if self._is_generator_just_started() or self._is_generator_exhausted(): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + # Raise GeneratorExit to see if user code catches it. Any other exception + # is propagated to the parent frame. + try: + self._setup_exception( + tx, variables.ExceptionVariable(GeneratorExit, ()) + ) + # There's an extra block on Python 3.12+ to handle StopIteration + # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397 + # + # 1 0 RETURN_GENERATOR + # 2 POP_TOP + # 4 RESUME 0 + + # 2 6 LOAD_CONST 1 (1) + # 8 YIELD_VALUE 1 + # 10 RESUME 1 + # 12 POP_TOP + # 14 RETURN_CONST 0 (None) + # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + # 18 RERAISE 1 + # ExceptionTable: + # 4 to 14 -> 16 [0] lasti + if ( + sys.version_info >= (3, 12) + and tracer.next_instruction.opname == "CALL_INTRINSIC_1" + ): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedGeneratorExit: + # If it doesn't catch, we just return None, as per the text above + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + try: + # Raise RuntimeError if the generator yields any other value + if self.next_variable(tx): + raise_observed_exception(RuntimeError, tx) + except ObservedGeneratorExit: + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedUserStopIteration: + # In Python 3.13+, one can capture GeneratorExit and return a value + # See test_generator.py::test_close_capture_GeneratorExit_return + # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26 + # https://github.com/python/cpython/pull/104771 + assert tracer.symbolic_result is not None + return 
tracer.symbolic_result super().call_method(tx, name, args, kwargs) From 53ab82d8f5c1f3f12b956cdd8745bf17262ef97e Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:21 -0300 Subject: [PATCH 16/28] Implement `generator.throw(exception)` (#144424) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144424 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422, #144423 --- test/dynamo/test_generator.py | 394 ++++++++++++++++++++++++++- torch/_dynamo/exc.py | 8 + torch/_dynamo/variables/functions.py | 100 +++++++ 3 files changed, 501 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 8aa86eb4bb3..f0e38d5f4a6 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -7,7 +7,7 @@ from collections import OrderedDict import torch import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.exc import Unsupported +from torch._dynamo.exc import InternalTorchDynamoError, Unsupported from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -977,6 +977,255 @@ class TestGeneratorClose(GeneratorTestsBase): self.assertEqual(z, 2) +class TestGeneratorThrow(GeneratorTestsBase): + def test_throw(self): + def whoo(t): + try: + yield t.sin() + except RuntimeError: + yield t.cos() + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.cos()) + + @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE") + def test_throw_with_finally(self): + z = 0 + + def whoo(): + nonlocal z + z = 0 + try: + try: + yield 1 + except ValueError: + yield 2 + finally: + z += 2 + except ValueError: + z += 33 + yield 4 + finally: + z += 1 + z += 10 + + def f(x): + gen = whoo() + next(gen) + gen.throw(ValueError) 
+ return x.sin() + + self._compile_check(f) + self.assertEqual(z, 3) + + def test_throw_without_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + z += 10 + except RuntimeError: + z += 100 + yield t.cos() + z += 1_000 + z += 10_000 + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.cos()) + self.assertEqual(z, 101) + + def test_throw_three_arguments(self): + def whoo(t): + try: + yield t.sin() + except ValueError: + yield t.cos() + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(ValueError, "Error", None) + return a + b + + t = torch.randn(2) + with self.assertRaises(InternalTorchDynamoError): + fn(t) + + def test_throw_no_yield_after_throw(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + finally: + z += 100 + + def fn(t): + gen = whoo(t) + a = next(gen) + try: + gen.throw(ValueError) + except StopIteration: + return a + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(z, 111) + self.assertEqual(y, t.sin()) + + def test_throw_not_catch(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + yield t.cos() + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + with self.assertRaises(RuntimeError): + fn(t) + + def test_throw_raise_difference_exc(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError as e: + z += 10 + raise RuntimeError from e + finally: + z += 100 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(ValueError) + return 
a + b + + t = torch.randn(2) + with self.assertRaises(RuntimeError): + fn(t) + + def test_throw_yield_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except RuntimeError: + z += 10 + yield t.cos() + finally: + z += 100 + yield t.tan() # RuntimeError: generator ignored GeneratorExit + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + with self.assertRaises(Unsupported): + fn(t) + + @unittest.skipIf(sys.version_info < (3, 11), "Missing RERAISE") + def test_throw_try_except_finally(self): + z = 0 + + def whoo(t): + nonlocal z + z = 0 + try: + z += 1 + yield t.sin() + except ValueError: + z += 10 + yield t.cos() + except RuntimeError: + z += 100 + yield t.tan() + finally: + z += 1000 + z += 10_000 + + def fn(t): + gen = whoo(t) + a = next(gen) + b = gen.throw(RuntimeError) + return a + b + + t = torch.randn(2) + y = self._compile_check(fn, (t,)) + self.assertEqual(y, t.sin() + t.tan()) + self.assertEqual(z, 1 + 100 + 1000) + + def test_exception_context_with_yield(self): + def f(): + yield + + def fn(t): + gen = f() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError: + z = 1 + except Exception as e: + raise AssertionError from e + assert z == 1 + return t.sin() + + self._compile_check(fn) + + class GeneratorCloseCPythonTests(GeneratorTestsBase): # Taken from commit # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 @@ -1118,6 +1367,149 @@ class GeneratorCloseCPythonTests(GeneratorTestsBase): fn(t) +class GeneratorThrowCpythonTests(GeneratorTestsBase): + # Taken from commit + # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 + # changed the tests a little bit to run them inside dynamo + # + replaced all self.assert* calls to plain assert statements + + @unittest.expectedFailure + def test_exception_context_with_yield(self): + def 
f(): + try: + raise KeyError("a") + except Exception: + yield + + def fn(t): + gen = f() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError as e: + context = e.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_exception_context_with_yield_inside_generator(self): + # Check that the context is also available from inside the generator + # with yield, as opposed to outside. + def f(): + z = 0 + try: + raise KeyError("a") + except Exception: + try: + yield + except Exception as exc: + z = 1 + assert type(exc) == ValueError + context = exc.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + yield "b" + finally: + assert z == 1 + + def fn(t): + gen = f() + gen.send(None) + actual = gen.throw(ValueError) + # This ensures that the assertions inside were executed. + assert actual == "b" + return t.sin() + + self._compile_check(fn) + + @unittest.expectedFailure + def test_exception_context_with_yield_from(self): + def f(): + yield + + def g(): + try: + raise KeyError("a") + except Exception: + yield from f() + + def fn(t): + gen = g() + gen.send(None) + try: + gen.throw(ValueError) + except ValueError as e: + context = e.__context__ + assert (type(context), context.args) == (KeyError, ("a",)) + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + @unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW") + @unittest.expectedFailure + def test_exception_context_with_yield_from_with_context_cycle(self): + # Check trying to create an exception context cycle: + # https://bugs.python.org/issue40696 + has_cycle = None + + def f(): + yield + + def g(exc): + nonlocal has_cycle + try: + raise exc + except Exception: + try: + yield from f() + except Exception as exc: + has_cycle = exc is exc.__context__ + yield + + def fn(t): + 
exc = KeyError("a") + gen = g(exc) + gen.send(None) + gen.throw(exc) + # This also distinguishes from the initial has_cycle=None. + assert has_cycle is False + return t.sin() + + self._compile_check(fn) + + def test_throw_after_none_exc_type(self): + def g(): + try: + raise KeyError + except KeyError: + pass + + try: + yield + except Exception: + raise RuntimeError # noqa: B904 + + def fn(t): + gen = g() + gen.send(None) + z = 0 + try: + gen.throw(ValueError) + except RuntimeError: + z += 1 + except Exception: + raise AssertionError # noqa: B904 + assert z == 1 + return t.sin() + + self._compile_check(fn) + + class GeneratorCPythonTests(GeneratorTestsBase): # Taken from commit # https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10 diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index a2f5f938bdc..aaa51f0a53c 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -309,6 +309,14 @@ observed_exception_map = { } +def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]: + if exc_type not in observed_exception_map: + observed_exception_map[exc_type] = type( + f"Observed{exc_type.__name__}Error", (ObservedException,), {} + ) + return observed_exception_map[exc_type] + + def raise_observed_exception( exc_type: type[Exception], tx: InstructionTranslatorBase, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 293ba87a516..d51be98f69f 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -16,7 +16,9 @@ import torch from .. 
import polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..exc import ( + get_dynamo_observed_exception, handle_observed_exception, + IncorrectUsage, InfiniteGeneratorError, ObservedException, ObservedGeneratorExit, @@ -604,6 +606,104 @@ class LocalGeneratorObjectVariable(VariableTracker): # https://github.com/python/cpython/pull/104771 assert tracer.symbolic_result is not None return tracer.symbolic_result + elif name == "throw": + # * Raises an exception at the point where the generator was paused, and + # returns the next value yielded by the generator. + # * If the generator exits without yielding, raise StopIteration + # * If the generator function does not catch the passed-in exception, + # or raises a different exception, then that exception propagates to the caller. + + if len(args) > 1: + raise IncorrectUsage( + "the (type, exc, tb) signature of throw() is deprecated, " + "use the single-arg signature instead." + ) + + # Setup the exception table and jump target in case of try...finally + tracer = self._get_inline_tracer(tx) + try: + self._setup_exception(tx, args[0]) + except ObservedException: + # propagate the exception back to the parent caller + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + raise + + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let’s walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... 
+ # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... + # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. 
+ pass + except ObservedException: + # Propagate anything else back to the parent caller + tx.exn_vt_stack.extend(tracer.exn_vt_stack) + else: + raise_observed_exception(RuntimeError, tracer) + return retval super().call_method(tx, name, args, kwargs) From 68cfd36c11137d283600ad95102fd22f025417a0 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:21 -0300 Subject: [PATCH 17/28] Add `CLEANUP_THROW` bytecode (#144420) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144420 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422, #144423, #144424 --- test/dynamo/test_generator.py | 38 ++++++++++++++++++++++++++----- torch/_dynamo/symbolic_convert.py | 9 ++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index f0e38d5f4a6..dabda71601c 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -593,6 +593,38 @@ class GraphModule(torch.nn.Module): self.assertEqual(i, 3) self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)]) + @unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW") + @unittest.expectedFailure + def test_cleanup_throw(self): + def nested_generator(): + try: + yield 1 + yield 2 + except StopIteration: + return 123 # noqa: B901 + + def outer_generator(): + yield from nested_generator() + yield 3 + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + gen = outer_generator() + next(gen) # Start the outer generator and enter the nested generato + + i = 0 + try: + # Force an exception while the generator is running + i = gen.throw(StopIteration("stop")) + except RuntimeError: + pass + return (i, t.sin()) + + t = torch.randn(2) + i, y = self._compile_check(fn, args=(t,)) + self.assertEqual(i, 3) + self.assertEqual(y, t.sin()) + def test_iter(self): def whoo(): i = 0 @@ -670,8 +702,6 @@ class TestGeneratorClose(GeneratorTestsBase): y = fn(t) self.assertEqual(y, 
t.sin()) - @unittest.expectedFailure - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_close_subgen(self): z = 0 @@ -844,8 +874,6 @@ class TestGeneratorClose(GeneratorTestsBase): with self.assertRaises(exc): fn(t) - @unittest.expectedFailure - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_close_with_subgen(self): L = [] z = 0 @@ -1450,8 +1478,6 @@ class GeneratorThrowCpythonTests(GeneratorTestsBase): self._compile_check(fn) - @unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW") - @unittest.expectedFailure def test_exception_context_with_yield_from_with_context_cycle(self): # Check trying to create an exception context cycle: # https://bugs.python.org/issue40696 diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index fd39e76032e..10532517862 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1511,6 +1511,15 @@ class InstructionTranslatorBase( self._raise_exception_variable(inst) unimplemented("raise ... from ...") + def CLEANUP_THROW(self, inst): + # https://github.com/python/cpython/pull/96010 + tos = self.stack[-1] + assert isinstance(tos, ExceptionVariable) + if tos.exc_type is StopIteration: + unimplemented("CLEANUP_THROW with StopIteration") + else: + self.RERAISE(inst) + def RERAISE(self, inst): if sys.version_info >= (3, 11): # RERAISE is currently supported in a narrow case of `raise ... 
from None` From 580a305681808ae364aa95b2eb425ad41852ef6b Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:22 -0300 Subject: [PATCH 18/28] Raise MutationError if there are side effects when returning generator (#145223) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145223 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422, #144423, #144424, #144420 --- test/dynamo/test_generator.py | 244 +++++++++++++++++++++++++-- torch/_dynamo/exc.py | 5 + torch/_dynamo/output_graph.py | 3 + torch/_dynamo/side_effects.py | 36 +++- torch/_dynamo/symbolic_convert.py | 28 +++ torch/_dynamo/variables/functions.py | 21 ++- 6 files changed, 313 insertions(+), 24 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index dabda71601c..764db540575 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -9,6 +9,7 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.exc import InternalTorchDynamoError, Unsupported from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm +from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -156,6 +157,182 @@ class GraphModule(torch.nn.Module): """, ) + def test_reconstruct_generator_with_local_var_mutation(self): + def whoo(t): + x = 0 + yield t.sin() + x + x += 1 + yield t.cos() + x + x += 1 + yield t.tan() + x + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + next(gen) + return t.sin(), gen + + t = torch.randn(2) + y, g = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(list(g), [t.cos() + 1, t.tan() + 2]) + + def test_reconstruct_generator_with_dict_mutation(self): + counters.clear() + + def whoo(t, d): + d[2] = t + yield t.sin() + yield t.cos() + d[3] = t + 1 + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, d): + gen 
= whoo(t, d) + next(gen) + return t.sin(), whoo(t, d) + + t = torch.randn(2) + d = {1: t} + fn(t, d) + self.assertEqual(len(counters["unimplemented"]), 1) + self.assertEqual( + dict(counters["unimplemented"]), + { + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications.": 1 + }, + ) + + def test_reconstruct_generator_with_dict_mutation_before(self): + def whoo(t, d): + d[2] = t + yield t.sin() + yield t.cos() + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, d): + gen = whoo(t, d) + next(gen) + return t.sin(), gen + + t = torch.randn(2) + d = {1: t} + y, g = fn(t, d) + self.assertEqual(y, t.sin()) + self.assertEqual(list(g), [t.cos(), t.tan()]) + self.assertEqual(d, {1: t, 2: t}) + + def test_reconstruct_generator_with_object_mutation(self): + class Counter: + def __init__(self): + self.x = 0 + + def incr(self): + self.x += 1 + + def whoo(t, c): + c.incr() + yield t.sin() + yield t.cos() + c.incr() + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, c): + gen = whoo(t, c) + next(gen) + return t.sin(), gen + + t = torch.randn(2) + c = Counter() + fn(t, c) + self.assertEqual(len(counters["unimplemented"]), 1) + self.assertEqual( + dict(counters["unimplemented"]), + { + "Cannot reconstruct a generator with variable mutations. 
" + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications.": 1 + }, + ) + + def test_reconstruct_generator_with_object_mutation_before(self): + class Counter: + def __init__(self): + self.x = 0 + + def incr(self): + self.x += 1 + + def whoo(t, c): + c.incr() + yield t.sin() + yield t.cos() + yield t.tan() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t, c): + gen = whoo(t, c) + next(gen) + # We should be able to reconstruct the generator as there's no object + # mutation after the first yield + return t.sin(), gen + + t = torch.randn(2) + c = Counter() + y, g = fn(t, c) + self.assertEqual(c.x, 1) + self.assertEqual(y, t.sin()) + self.assertEqual(list(g), [t.cos(), t.tan()]) + + def test_graph_break_and_reconstruct_generator(self): + def whoo(t): + yield t.sin() + yield t.cos() + yield t.tan() + + def g(t): + torch._dynamo.graph_break() + + @torch.compile(backend="eager", fullgraph=False) + def fn(t): + gen = whoo(t) + next(gen) + g(t) + return t.sin(), list(gen) + + t = torch.randn(2) + y, gen = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(list(gen), [t.cos(), t.tan()]) + + def test_graph_break_in_generator_while_reconstructing(self): + counters.clear() + + def whoo(): + yield 1 + torch._dynamo.graph_break() + yield 2 + + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) + def fn(t): + gen = whoo() + s = next(gen) + torch._dynamo.graph_break() + s += next(gen) + return t + s + + t = torch.randn(2) + y = fn(t) + self.assertEqual(y, t + 3) + self.assertEqual(len(eager.graphs), 0) + def test_generator_as_argument(self): # The inline tracer needs to be kept in sync if an already advanced generator # is given to a compiled function. 
@@ -404,6 +581,22 @@ class GraphModule(torch.nn.Module): with self.assertRaises(StopIteration): next(gen) + @unittest.expectedFailure + def test_reconstruct_generator_tensor_mutation(self): + def whoo(t): + yield t.sin_() + yield t.cos_() + + def fn(t): + gen = whoo(t) + return gen + + with self.assertRaisesRegex( + Unsupported, + "Cannot reconstruct a generator with variable mutations", + ): + self._compile_check(fn) + def test_subgenerator(self): def subgen(t): yield t + 1 @@ -509,8 +702,8 @@ class GraphModule(torch.nn.Module): got = torch.compile(backend="eager", fullgraph=False)(fn)(t) self.assertEqual(expected, got) - @unittest.expectedFailure def test_generator_with_side_effects(self): + counters.clear() i = 0 def whoo(t): @@ -519,14 +712,22 @@ class GraphModule(torch.nn.Module): i += 1 yield t + j + @torch.compile(backend="eager", fullgraph=True) def fn(t): return whoo(t), t.sin() t = torch.randn(2) - with self.assertRaises(Unsupported): - fn(t) + fn(t) + self.assertEqual(len(counters["unimplemented"]), 1) + self.assertEqual( + dict(counters["unimplemented"]), + { + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications.": 1 + }, + ) - @unittest.expectedFailure def test_subgenerator_with_side_effects(self): i = 0 @@ -547,13 +748,20 @@ class GraphModule(torch.nn.Module): i += 1 yield t + 4 + @torch.compile(backend="eager", fullgraph=True) def fn(t): return whoo(t), t.sin() - with self.assertRaises(Unsupported): - self._compile_check(fn) + t = torch.randn(2) + gen, y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(len(list(gen)), 5) + self.assertTrue( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." 
in dict(counters["unimplemented"]) + ) - @unittest.expectedFailure def test_generator_with_side_effects_graph_break(self): i = 0 @@ -567,11 +775,18 @@ class GraphModule(torch.nn.Module): def fn(t): gen = whoo(t) torch._dynamo.graph_break() - return list(zip(range(3), gen)) + next(gen) + return gen, t.sin() t = torch.randn(2) - with self.assertRaises(Unsupported): - fn(t) + gen, y = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(len(list(gen)), 4) + self.assertTrue( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." in dict(counters["unimplemented"]) + ) def test_generator_with_side_effects_graph_break_2(self): i = 0 @@ -583,15 +798,16 @@ class GraphModule(torch.nn.Module): yield t + j torch._dynamo.graph_break() - @torch.compile(backend="eager", fullgraph=False) + eager = EagerAndRecordGraphs() + + @torch.compile(backend=eager, fullgraph=False) def fn(t): gen = whoo(t) return list(zip(range(3), gen)) t = torch.randn(2) - y = fn(t) - self.assertEqual(i, 3) - self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)]) + fn(t) + self.assertEqual(len(eager.graphs), 0) @unittest.skipIf(sys.version_info < (3, 12), "Test CLEANUP_THROW") @unittest.expectedFailure diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index aaa51f0a53c..7a524d5017f 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -168,6 +168,11 @@ class InfiniteGeneratorError(Unsupported): super().__init__(msg) +class SideEffectsError(Unsupported): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + class CondOpArgsMismatchError(ArgsMismatchError): """ Internal error from cond() due to arguments mismatch. 
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 35219afb9f0..818ef871988 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1936,6 +1936,9 @@ class SubgraphTracer(fx.Tracer): # backward recomputation of the checkpoint region doesn't affect its correctness. self.allow_side_effects_under_checkpoint = False + # True if this tracer is currently tracing (reconstructing) into a Python generator + self.is_reconstructing_generator = False + self.debug_level: int = parent.debug_level + 1 if parent is not None else 0 self._cur_code = None diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index c24f4d821c3..2cdf6a0cc41 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -7,7 +7,7 @@ import warnings import weakref from collections.abc import MutableMapping from types import CellType -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING import torch.nn @@ -19,7 +19,7 @@ from .bytecode_transformation import ( create_instruction, ) from .codegen import PyCodegen -from .exc import unimplemented +from .exc import SideEffectsError, unimplemented from .source import GlobalSource, LocalCellSource, LocalSource, Source from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new from .variables.base import ( @@ -34,6 +34,10 @@ from .variables.base import ( from .variables.user_defined import FrozenDataClassVariable +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + def _manual_dict_setitem(dict_from, dict_to, mro_index): # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. 
We have # to be careful because we don't want to trigger the user defined object @@ -134,6 +138,14 @@ class SideEffects: and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint ) + def is_reconstructing_generator(self): + output_graph = self.output_graph_weakref() + + return ( + output_graph + and output_graph.current_tx.output.current_tracer.is_reconstructing_generator + ) + def check_allowed_side_effect(self, item): from torch._dynamo.variables.misc import AutogradFunctionContextVariable @@ -143,6 +155,14 @@ class SideEffects: return True if self.should_allow_side_effects_under_checkpoint(): return True + if self.is_reconstructing_generator(): + # This is missing the case where one mutates a tensor. See + # test_generator.py::test_reconstruct_generator_tensor_mutation + raise SideEffectsError( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." 
+ ) if not is_side_effect_safe(item.mutation_type): unimplemented( "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)" @@ -842,7 +862,7 @@ class SideEffects: @contextlib.contextmanager -def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821 +def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): assert tx.output.current_tracer.under_activation_checkpoint orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint try: @@ -850,3 +870,13 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: i yield finally: tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val + + +@contextlib.contextmanager +def disallow_side_effects_in_generator(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.is_reconstructing_generator + try: + tx.output.current_tracer.is_reconstructing_generator = True + yield + finally: + tx.output.current_tracer.is_reconstructing_generator = orig_val diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 10532517862..9a8edf94478 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -291,6 +291,34 @@ def _step_logger(): return torchdynamo_logging.get_step_logger(log) +@contextlib.contextmanager +def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"): + # When reconstructing a generator after a graph break, we advance it until + # it is fully exhausted. This process adds new entries to the speculation + # log that were not previously observed. Without temporarily clearing the + # speculation log, this could lead to a divergence error. 
+ + entries = tx.speculation_log.entries + index = tx.speculation_log.index + try: + tx.speculation_log.entries = [] + tx.speculation_log.index = 0 + yield + finally: + tx.speculation_log.entries = entries + tx.speculation_log.index = index + + +@contextlib.contextmanager +def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"): + try: + tmp = tx.output.should_exit + tx.output.should_exit = False + yield + finally: + tx.output.should_exit = tmp + + @dataclasses.dataclass class BlockStackEntry: # Current instruction that pushes something to block_stack diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index d51be98f69f..45e04bbb80c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -428,18 +428,23 @@ class LocalGeneratorObjectVariable(VariableTracker): __repr__ = __str__ def reconstruct(self, codegen): - from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.side_effects import disallow_side_effects_in_generator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + save_and_restart_speculation_log, + temporarely_allow_writes_to_output_graph, + ) tx = InstructionTranslator.current_tx() - tracer = self._get_inline_tracer(tx) - try: - prev = tx.output.should_exit - tx.output.should_exit = False + save = save_and_restart_speculation_log(tx) + disallow = disallow_side_effects_in_generator(tx) + temp = temporarely_allow_writes_to_output_graph(tx) + + with save, disallow, temp: + tracer = self._get_inline_tracer(tx) if not tracer.generator_exhausted: self.remaining_items = self.force_unpack_var_sequence(tx) variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) - finally: - tx.output.should_exit = prev def bind_args(self, tx, args, kwargs): return self.fn.bind_args(tx, args, kwargs) @@ -480,6 +485,8 @@ class LocalGeneratorObjectVariable(VariableTracker): except Unsupported as e: 
torch._C._dynamo.eval_frame.skip_code(self.get_code()) raise SkipFrame from e + finally: + counters["unimplemented"] |= counters["inline_call"] def has_unpack_var_sequence(self, tx): return False From 6a9a02acbe34a9d810c8bf56c865b9d0687a3051 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 7 Feb 2025 14:55:22 -0300 Subject: [PATCH 19/28] Set `enable_faithful_generator_behavior` flag to True (#142513) Pull Request resolved: https://github.com/pytorch/pytorch/pull/142513 Approved by: https://github.com/zou3519 ghstack dependencies: #141055, #144421, #144422, #144423, #144424, #144420, #145223 --- test/dynamo/test_ctx_manager.py | 10 +++++----- test/dynamo/test_misc.py | 15 --------------- test/dynamo/test_repros.py | 4 ++-- ..._eval_mode_swap_True_set_grad_True_cpu_float32 | 0 ...train_mode_swap_True_set_grad_True_cpu_float32 | 0 ..._eval_mode_swap_True_set_grad_True_cpu_float32 | 0 ...train_mode_swap_True_set_grad_True_cpu_float32 | 0 ..._eval_mode_swap_True_set_grad_True_cpu_float32 | 0 ...train_mode_swap_True_set_grad_True_cpu_float32 | 0 ...n_Bilinear_swap_True_set_grad_True_cpu_float32 | 0 ..._nn_Conv1d_swap_True_set_grad_True_cpu_float32 | 0 ..._nn_Conv2d_swap_True_set_grad_True_cpu_float32 | 0 ..._nn_Conv3d_swap_True_set_grad_True_cpu_float32 | 0 ...ranspose1d_swap_True_set_grad_True_cpu_float32 | 0 ...ranspose2d_swap_True_set_grad_True_cpu_float32 | 0 ...ranspose3d_swap_True_set_grad_True_cpu_float32 | 0 ...nn_GRUCell_swap_True_set_grad_True_cpu_float32 | 0 ..._GroupNorm_swap_True_set_grad_True_cpu_float32 | 0 ...n_LSTMCell_swap_True_set_grad_True_cpu_float32 | 0 ..._LayerNorm_swap_True_set_grad_True_cpu_float32 | 0 ..._nn_Linear_swap_True_set_grad_True_cpu_float32 | 0 ...nn_RNNCell_swap_True_set_grad_True_cpu_float32 | 0 ...eval_mode_swap_True_set_grad_True_cuda_float32 | 0 ...rain_mode_swap_True_set_grad_True_cuda_float32 | 0 ...eval_mode_swap_True_set_grad_True_cuda_float32 | 0 ...rain_mode_swap_True_set_grad_True_cuda_float32 | 0 
...eval_mode_swap_True_set_grad_True_cuda_float32 | 0 ...rain_mode_swap_True_set_grad_True_cuda_float32 | 0 ..._Bilinear_swap_True_set_grad_True_cuda_float32 | 0 ...nn_Conv1d_swap_True_set_grad_True_cuda_float32 | 0 ...nn_Conv2d_swap_True_set_grad_True_cuda_float32 | 0 ...nn_Conv3d_swap_True_set_grad_True_cuda_float32 | 0 ...anspose1d_swap_True_set_grad_True_cuda_float32 | 0 ...anspose2d_swap_True_set_grad_True_cuda_float32 | 0 ...anspose3d_swap_True_set_grad_True_cuda_float32 | 0 ...n_GRUCell_swap_True_set_grad_True_cuda_float32 | 0 ...GroupNorm_swap_True_set_grad_True_cuda_float32 | 0 ..._LSTMCell_swap_True_set_grad_True_cuda_float32 | 0 ...LayerNorm_swap_True_set_grad_True_cuda_float32 | 0 ...nn_Linear_swap_True_set_grad_True_cuda_float32 | 0 ...n_RNNCell_swap_True_set_grad_True_cuda_float32 | 0 torch/_dynamo/config.py | 2 +- 42 files changed, 8 insertions(+), 23 deletions(-) delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Bilinear_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv1d_swap_True_set_grad_True_cpu_float32 delete mode 100644 
test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv2d_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv3d_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_GRUCell_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_GroupNorm_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_LSTMCell_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_LayerNorm_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Linear_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_RNNCell_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cuda_float32 delete mode 100644 
test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Bilinear_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv1d_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv2d_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv3d_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GRUCell_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GroupNorm_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LSTMCell_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LayerNorm_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Linear_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RNNCell_swap_True_set_grad_True_cuda_float32 diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index bf975827cf6..c874a578b4e 100644 --- 
a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -2237,7 +2237,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_graph_break_before_and_after___enter__(self): @contextlib.contextmanager @@ -2263,7 +2263,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_graph_break_before___enter___and_disable___exit__(self): @contextlib.contextmanager @@ -2293,7 +2293,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_disable___enter__(self): def h(x): @@ -2574,7 +2574,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False)(fn)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 1) + self.assertEqual(len(eager.graphs), 0) def test_dynamo_disable_ctx(self): @contextlib.contextmanager @@ -2624,7 +2624,7 @@ class GraphModule(torch.nn.Module): eager = EagerAndRecordGraphs() out = torch.compile(backend=eager, fullgraph=False, dynamic=False)(f)(x) self.assertEqual(expected, out) - self.assertEqual(len(eager.graphs), 3) + self.assertEqual(len(eager.graphs), 2) @parametrize("name", ("suppress", "stdout", "stderr")) def test_contextlib_suppress(self, name): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 746f25a2e8a..239412b8370 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9579,21 +9579,6 @@ def ___make_guard_fn(): ): 
compiled_fn(x) - # FIXME(XuehaiPan): do not inline infinite generator if it does not raise errors in eager mode - def fn(x): - def gen(): - while True: - yield x - - return list(zip(range(10), gen())) - - x = torch.randn([0, 1, 2, 3, 4, 5]) - compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, "infinite generator" - ): - compiled_fn(x) - def test_itertools_islice(self): counters.clear() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index eae6ec46477..dc3d58f5b0d 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1418,9 +1418,9 @@ class ReproTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(opt_model(a, b, c, d), correct)) if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline(cnt.frame_count, """4""") + self.assertExpectedInline(cnt.frame_count, """2""") else: - self.assertExpectedInline(cnt.frame_count, """5""") + self.assertExpectedInline(cnt.frame_count, """3""") def test_hf_model_output(self): ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cpu_float32 
deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Bilinear_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Bilinear_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv1d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv1d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv2d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv2d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv3d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Conv3d_swap_True_set_grad_True_cpu_float32 
deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GRUCell_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GRUCell_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GroupNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_GroupNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LSTMCell_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LSTMCell_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LayerNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_LayerNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index 
e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Linear_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Linear_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RNNCell_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RNNCell_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_eval_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm2d_train_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cuda_float32 
b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_eval_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_BatchNorm3d_train_mode_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Bilinear_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Bilinear_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv1d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv1d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv2d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv2d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv3d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Conv3d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose1d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git 
a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose2d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_ConvTranspose3d_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GRUCell_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GRUCell_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GroupNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_GroupNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LSTMCell_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LSTMCell_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LayerNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_LayerNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Linear_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Linear_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git 
a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RNNCell_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RNNCell_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index c3b57afc881..f033a282d27 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -419,7 +419,7 @@ enable_trace_contextlib = True # Enable tracing generator functions lazily. If False, Dynamo will exhaust # generators upon first execution. And if True, the generator will be accessed lazily -enable_faithful_generator_behavior = False +enable_faithful_generator_behavior = True # Inline inbuilt nn modules inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated] From 0e83e7d56e60e05de5f6ccbb95aee498c7a739dc Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Sat, 8 Feb 2025 23:40:23 +0000 Subject: [PATCH 20/28] [EZ] Add logic to build Metal shader with debug info (#146768) By appending `-frecord-sources -gline-tables-only` to the compilation command Helpful when debugging shaders compiled into libtorch Test plan: Run `python ../tools/build_with_debinfo.py ../aten/src/ATen/native/mps/kernels/UpSample.metal ../aten/src/ATen/native/mps/operations/UpSample.mm` And then run following to capture shader and check that it contains debug info ```python import torch import os os.environ["MTL_CAPTURE_ENABLED"]="1" inp = torch.rand(size=(6, 3, 10, 20), device="mps", dtype=torch.float32) with torch.mps.profiler.metal_capture("bilinear2d"): out = torch.nn.functional.interpolate(x, scale_factor=(1.7,0.9), mode="bilinear") ``` image Pull Request resolved: https://github.com/pytorch/pytorch/pull/146768 Approved by: https://github.com/dcci --- tools/build_with_debinfo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index 
26c054bf2a0..0c9553b963e 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -78,6 +78,9 @@ def create_build_plan() -> list[tuple[str, str]]: if line.startswith(": &&") and line.endswith("&& :"): line = line[4:-4] line = line.replace("-O2", "-g").replace("-O3", "-g") + # Build Metal shaders with debug information + if "xcrun metal " in line and "-frecord-sources" not in line: + line += " -frecord-sources -gline-tables-only" try: name = line.split("-o ", 1)[1].split(" ")[0] rc.append((name, line)) From 91c4bf39d39f0607833b2a226e48a9ab7262c906 Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Sun, 9 Feb 2025 05:11:17 +0000 Subject: [PATCH 21/28] [mps] Add a shader for spherical_bessel_j0. (#146771) In preparation for adding the operation to inductor/eager. Adapted from the CUDA version of the shader. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146771 Approved by: https://github.com/malfet --- c10/metal/special_math.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 8bcb1f7a53e..04fd7eee18f 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) { return float2(re, im) / a2; } +template +inline T spherical_bessel_j0(T x) { + if (::metal::isinf(x)) + return T(0.0); + T x2 = x * x; + T k1 = static_cast(-1.0); + T k2 = static_cast(1.0); + + if (::metal::abs(x) < T(0.5)) { + return T(1.0) + + x2 * + (k1 / T(6.0) + + x2 * + (k2 / T(120.0) + + x2 * + (k1 / T(5040.0) + + x2 * + (k2 / T(362880.0) + + x2 * + (k1 / T(39916800.0) + + x2 * (k2 / T(6227020800.0))))))); + } + + return ::metal::sin(x) / x; +} + } // namespace metal } // namespace c10 From b133907d0ab2fe5250681ca3407b8c16fb74fdd5 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 7 Feb 2025 20:48:51 -0800 Subject: [PATCH 22/28] Update strided test to float32 (#146748) Fixes #146377 Pull Request resolved: 
https://github.com/pytorch/pytorch/pull/146748 Approved by: https://github.com/BoyuanFeng, https://github.com/leijurv --- test/inductor/test_flex_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 8b4382061b0..99440593c2b 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform def test_strided_backwards(self): shape = (1, 2, 4096, 64) - Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) - K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) - V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16) + Q = torch.randn(shape, requires_grad=True, device="cuda") + K = torch.randn(shape, requires_grad=True, device="cuda") + V = torch.randn(shape, requires_grad=True, device="cuda") func = torch.compile(flex_attention, dynamic=True, fullgraph=True) K_sliced = K[:, :, :-128] From 2a55311773bfb9e569aa672ac3322d21abc1af32 Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Sun, 9 Feb 2025 20:09:34 +0000 Subject: [PATCH 23/28] [cuda] Simplify the sinc function a bit. (#146774) `else` after `return` can be removed & the indentation can be reduced, for readability. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146774 Approved by: https://github.com/malfet --- aten/src/ATen/native/cuda/Math.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index b99e9d0c94d..2fe8f5dd2e3 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify( T sinc(T a) { if (a == T(0)) { return T(1); - } else { - constexpr T pi = T(3.14159265358979323846L); - T product = pi * a; - return std::sin(product) / product; } + constexpr T pi = T(3.14159265358979323846L); + T product = pi * a; + return std::sin(product) / product; } ); // sinc_string From 298226f358239048ccf93ab80d045d9137cb2bb3 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 7 Feb 2025 10:33:32 -0800 Subject: [PATCH 24/28] [dynamo] check for incompatible configs (#146513) internal: https://fb.workplace.com/groups/1075192433118967/permalink/1599802033991335/ Assuming flags don't change during compilation, we shouldn't allow incompatible configs to be set at torch.compile wrap time. 
Not in this PR: For flags that need to change during compilation, we'd have to be strict about where they can be used in the compile lifecycle Pull Request resolved: https://github.com/pytorch/pytorch/pull/146513 Approved by: https://github.com/williamwen42 Co-authored-by: Gabriel Ferns --- test/dynamo/test_repros.py | 21 +++++++++++++++++++++ torch/_dynamo/config.py | 2 ++ torch/_dynamo/eval_frame.py | 8 ++++++++ 3 files changed, 31 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index dc3d58f5b0d..eb1a8d2d6ca 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -6510,6 +6510,27 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): ).sum() self.assertEqual(actual, expected) + def test_incompatible_configs(self): + with torch._dynamo.config.patch( + suppress_errors=False, fail_on_recompile_limit_hit=False + ): + torch.compile(lambda: None) + + with torch._dynamo.config.patch( + suppress_errors=True, fail_on_recompile_limit_hit=False + ): + torch.compile(lambda: None) + + with torch._dynamo.config.patch( + suppress_errors=False, fail_on_recompile_limit_hit=True + ): + torch.compile(lambda: None) + + with torch._dynamo.config.patch( + suppress_errors=True, fail_on_recompile_limit_hit=True + ), self.assertRaises(AssertionError): + torch.compile(lambda: None) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index f033a282d27..84ed3be5a70 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -52,6 +52,7 @@ skip_code_recursive_on_recompile_limit_hit = True # raise a hard error if cache limit is hit. If you are on a model where you # know you've sized the cache correctly, this can help detect problems when # you regress guards/specialization. This works best when recompile_limit = 1. +# This flag is incompatible with: suppress_errors. 
# [@compile_ignored: runtime_behaviour] fail_on_recompile_limit_hit = False @@ -164,6 +165,7 @@ traceable_tensor_subclasses: set[type[Any]] = set() # This is a good way to get your model to work one way or another, but you may # lose optimization opportunities this way. Devs, if your benchmark model is failing # this way, you should figure out why instead of suppressing it. +# This flag is incompatible with: fail_on_recompile_limit_hit. suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) # Record and write an execution record of the current frame to a file diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 045cd350b60..789ed41d3a2 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -833,6 +833,13 @@ def is_inductor_supported(): return False +def check_for_incompatible_configs(): + # Some of the configs should be mutually exclusive + assert not ( + config.suppress_errors and config.fail_on_recompile_limit_hit + ), "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time." + + def optimize(*args, **kwargs): def rebuild_ctx(): ca_kwargs_override = config.compiled_autograd_kwargs_override @@ -885,6 +892,7 @@ def _optimize( ... """ check_if_dynamo_supported() + check_for_incompatible_configs() # Note: The hooks object could be global instead of passed around, *however* that would make # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. 
# There is some prior art around this, w/r/t nesting backend calls are enforced to be the same From e8304f08fedc802a90f9361c30861f8c5aab946e Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Mon, 10 Feb 2025 01:19:30 +0000 Subject: [PATCH 25/28] Fix torch.take_along_dim param type and default description (#146474) ## Changes - Change type description to `LongTensor`, consistent with [`torch.take`](https://pytorch.org/docs/stable/generated/torch.take.html) - Add `dim` param default value description ## Test Result **Before** ![image](https://github.com/user-attachments/assets/720ce158-2bc1-48b5-a188-56fcc7188d96) **After** ![image](https://github.com/user-attachments/assets/05fe20bd-9476-4b97-ac2b-9b161d6532a1) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146474 Approved by: https://github.com/mikaylagawarecki --- torch/_torch_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 057ed0fe63e..2dd16890880 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -11100,8 +11100,8 @@ are designed to work with this function. See the examples below. Args: {input} - indices (tensor): the indices into :attr:`input`. Must have long dtype. - dim (int, optional): dimension to select along. + indices (LongTensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. Default: 0 Keyword args: {out} From 387c993c3b0bd04348eea14bcbef87755b0dae0f Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 7 Feb 2025 15:39:47 -0800 Subject: [PATCH 26/28] [ca] remove private API: _compiled_autograd_should_lift (#146720) Since the functional autograd + compiled autograd migration, we don't trace into nodes anymore, and everything is lifted. We can't support this flag which tries to inline make_fx style in CA initial pass. There's no more usage internally. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146720 Approved by: https://github.com/zou3519 --- torch/_functorch/_aot_autograd/runtime_wrappers.py | 2 -- torch/autograd/function.py | 3 --- torch/csrc/autograd/python_function.cpp | 8 -------- torch/csrc/autograd/python_function.h | 2 -- 4 files changed, 15 deletions(-) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index dc9e5af16da..10f1767167a 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1822,7 +1822,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta num_symints_saved_for_bw = num_symints_saved_for_bw_ - _compiled_autograd_should_lift = False _aot_id = aot_config.aot_id _lazy_backward_info = lazy_backward_info @@ -1989,7 +1988,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107 class CompiledFunctionBackward(torch.autograd.Function): # CompiledFunctionBackward is not yet supported in dynamo skipfiles - _compiled_autograd_should_lift = False _aot_id = aot_config.aot_id @staticmethod diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 1bcb2575145..219759ea37b 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -331,9 +331,6 @@ class FunctionMeta(type): name + "Backward", (BackwardCFunction,), {"_forward_cls": cls} ) backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined] - backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined] - "_compiled_autograd_should_lift", True - ) backward_fn._bw_module = None # type: ignore[attr-defined] if getattr(cls, 
"_lazy_backward_info", None): backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined] diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index dd0b7a927bf..67a307cd043 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -301,14 +301,6 @@ bool PyNode::is_aot_backward() const { return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); } -auto PyNode::compiled_autograd_should_lift() const -> bool { - pybind11::gil_scoped_acquire gil; - static PyObject* attr_name = - PyUnicode_InternFromString("_compiled_autograd_should_lift"); - THPObjectPtr should_lift(PyObject_GetAttr(obj, attr_name)); - return PyObject_IsTrue(should_lift.get()) == 1; -} - void PyNode::compiled_args(CompiledNodeArgs& args) { static PyObject* method_name = PyUnicode_InternFromString("_compiled_autograd_key"); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 2f28c765ab0..f6f0979dc25 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -50,8 +50,6 @@ struct PyNode : public Node { const variable_list& inputs, SwapSavedVariables& saved) override; - bool compiled_autograd_should_lift() const; - // THPFunction this Function is wrapping. Owning! PyObject* obj; From effc5452748519bad995a5ffbb213523a76e0d9f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 7 Feb 2025 14:05:50 -0800 Subject: [PATCH 27/28] [DDP] Use NCCL allocated memory for gradient bucket (#146589) So that NVLink SHARP comes with zero-copy on H100+ platforms, for DDP applications. Less SM usage, less memory contention between NCCL kernel and compute kernels. Added env `DDP_DISABLE_COMM_MEM` as a back-out option: ``` An environment variable to disable comm-optimized memory pool. Default is 0, which means comm-optimized memory pool is enabled. 
Users can set it to 1 in case of seeing regression or OOM (because this comm MemPool may not share space with regular compute MemPool). ``` Differential Revision: [D69297766](https://our.internmc.facebook.com/intern/diff/D69297766) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146589 Approved by: https://github.com/syed-ahmed, https://github.com/c-p-i-o, https://github.com/fduwjj --- test/distributed/test_c10d_nccl.py | 4 +- torch/csrc/distributed/c10d/Backend.hpp | 15 ++++++ torch/csrc/distributed/c10d/NCCLUtils.cpp | 21 +++++--- torch/csrc/distributed/c10d/NCCLUtils.hpp | 5 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 51 ++++++++++++++++++- .../distributed/c10d/ProcessGroupNCCL.hpp | 10 ++++ torch/csrc/distributed/c10d/reducer.cpp | 34 ++++++++++++- 7 files changed, 127 insertions(+), 13 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 2ab444a4b68..522b6815ada 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2007,7 +2007,7 @@ class DistributedDataParallelTest( replica_devices = [dev0] # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device. layer_devs = dev0 - local_batch_size = 8 + local_batch_size = 16 self._test_grad_layout(replica_devices, layer_devs, local_batch_size) @requires_nccl() @@ -2021,7 +2021,7 @@ class DistributedDataParallelTest( replica_devices = None # Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device. 
layer_devs = [dev0] * 2 + [dev1] * 2 - local_batch_size = 8 + local_batch_size = 16 self._test_grad_layout(replica_devices, layer_devs, local_batch_size) @requires_nccl() diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 9d188c9c26d..ff83d687f8a 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -417,6 +417,21 @@ class TORCH_API Backend : public torch::CustomClassHolder { "Backend ", getBackendName(), " does not support getMemAllocator")); } + // Allocate tensor (aten::empty) from backend's communication-optimized memory + // pool + virtual at::Tensor allocateTensor(long size, at::TensorOptions options = {}) { + TORCH_CHECK( + false, + c10::str( + "Backend ", getBackendName(), " does not support allocateTensor")); + } + + // Returns true if backend supports tensor allocation + virtual bool supportsTensorAlloc() { + // Change to true in concrete backend if supported + return false; + } + protected: // Implementations of this interface need to call this to setup // appropriate logging etc. diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 9b5c5962479..99fc244af02 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -340,19 +340,26 @@ ncclResult_t NCCLComm::checkForNcclError() { #endif } -ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) { +ncclResult_t NCCLComm::registerSegment( + void* ptr, + size_t size, + bool errorOnRereg /*=true*/) { LockType lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always // maps to a unique handle and should not be registered before the current // ptr is deregistered and freed. 
- TORCH_CHECK( - registeredSegmentHandles_.count(ptr) == 0, - "Segment with ptr ", - ptr, - " has already been registered on ncclComm_ ", - ncclComm_); + if (registeredSegmentHandles_.count(ptr) > 0) { + TORCH_CHECK( + !errorOnRereg, + "Segment with ptr ", + ptr, + " has already been registered on ncclComm_ ", + ncclComm_); + // Skip below + return ncclSuccess; + } void* handle = nullptr; // Use getNcclComm to make sure comm is ready before calling nccl APIs diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 1ec81494856..c7cd0a30924 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -284,7 +284,10 @@ class NCCLComm { ncclResult_t checkForNcclError(); - ncclResult_t registerSegment(void* ptr, size_t size); + ncclResult_t registerSegment( + void* ptr, + size_t size, + bool errorOnRereg = true); ncclResult_t deregisterSegment(void* ptr); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d69fb2f5c36..cd9363ec337 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1175,7 +1175,8 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { ncclComm->registerSegment( // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(segmentInfo.address), - segmentInfo.total_size); + segmentInfo.total_size, + /*errorOnRereg=*/false); // ignores reregistration error } } @@ -1455,6 +1456,14 @@ void ProcessGroupNCCL::shutdown() { // Use long interval to avoid acquiring CPU too frequently ncclComm->waitReady(true); } + // Deregister memory pool after finalizing all collectives + if (memPool_) { + try { + deregisterMemPool(memPool_.get()); + } catch (...) 
{ + LOG(ERROR) << logPrefix() << "Failed to deregister memory pool, ignoring"; + } + } // Tell watchdog to (1) flush its queue and (2) do not use comm objects // anymore because I am going to destroy them now LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread."; @@ -5422,6 +5431,46 @@ std::shared_ptr ProcessGroupNCCL::getMemAllocator() { return ncclMemAllocator; } +at::Tensor ProcessGroupNCCL::allocateTensor( + long size, + at::TensorOptions options) { + // Some checks + TORCH_CHECK_VALUE(options.has_device(), "Tensor options must include device"); + auto device = options.device(); + TORCH_CHECK_VALUE( + device.is_cuda(), + "NCCL tensor allocator expects cuda type but got " + c10::str(device)) + + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Create memory pool + if (!memPool_) { + // Needs a CUDAAllocator + auto allocator = + reinterpret_cast( + getMemAllocator().get()); + // Pool is created + memPool_ = std::make_unique(allocator); + LOG(INFO) << logPrefix() << "Created memory pool"; + } + + // Allocate tensor under this MemPool's context + auto ctx = c10::cuda::MemPoolContext(memPool_.get()); + c10::cuda::CUDACachingAllocator::beginAllocateToPool( + memPool_->device(), memPool_->id(), [](cudaStream_t) { return true; }); + at::Tensor tensor = at::empty({size}, options); + // Also need to ncclCommRegister the pool in case new segments are created; + // reregistration of old segments will be ignored + registerMemPool(memPool_.get()); + c10::cuda::CUDACachingAllocator::endAllocateToPool( + memPool_->device(), memPool_->id()); + c10::cuda::CUDACachingAllocator::releasePool( + memPool_->device(), memPool_->id()); + LOG(INFO) << logPrefix() << "Allocated tensor of size " << size + << " from memory pool"; + return tensor; +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 002b3a1a143..185d9bebe6e 100644 --- 
a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -774,6 +774,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::shared_ptr getMemAllocator() override; + // Allocate tensor from communication-optimized memory pool + at::Tensor allocateTensor(long size, at::TensorOptions options = {}) override; + + bool supportsTensorAlloc() override { + return true; + } + // Performs NCCL user buffer registration for all buffers in // the given MemPool void registerMemPool(c10::cuda::MemPool* pool); @@ -1294,6 +1301,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Internal cached value: use NCCL non-blocking API mode or not. // Use `useNonblocking()` method instead of accessing this variable directly. std::optional useNonblocking_{std::nullopt}; + + // Communication-optimized memory pool associated with this PG + std::unique_ptr memPool_ = nullptr; }; // Dumps the NCCL comm traces and additional information about the Process diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 03c1380bfe7..800269fe14e 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -1157,14 +1157,44 @@ void Reducer::initialize_buckets( offset += length; } - // Allocate the bucket's flattened `gradients` tensor. // Make gradient type in the reduced precision if mixed precision is // enabled. This ensures that the type is correct when e.g. rebuilding // buckets. if (mixed_precision_param_dtype_.has_value()) { options = options.dtype(mixed_precision_param_dtype_); } - bucket.gradients = at::empty({static_cast(offset)}, options); + + // Allocate the bucket's flattened `gradients` tensor. + auto bucketSize = static_cast(offset); + // Check if we can use comm-optimized memory pool to allocate tensor + c10::intrusive_ptr backend = nullptr; + // An environment variable to disable comm-optimized memory pool. 
+ // Default is 0, which means comm-optimized memory pool is enabled. + // Users can set it to 1 in case of seeing regression or OOM (because this + // comm MemPool may not share space with regular compute MemPool). + bool ddpDisableCommMem = + (getCvarString({"DDP_DISABLE_COMM_MEM"}, "0") == "1"); + try { + backend = process_group_->getDefaultBackend(); + } catch (...) { + // Sometimes the backend type can be `UNDEFINED` rather than `NCCL` or + // `GLOO`. In this case, we just fall back to the regular way of + // creating tensor + LOG(INFO) + << "Reducer: default comm backend not found, skipping bucket memory optimization"; + } + if (ddpDisableCommMem == 0 && backend != nullptr && + backend->supportsTensorAlloc()) { + // Comm-optimized memory pool is available, use it to allocate tensor + LOG(INFO) + << "Reducer: found comm-optimized memory allocator, using it to create bucket"; + bucket.gradients = backend->allocateTensor(bucketSize, options); + } else { + // Plain creation of tensor + LOG(INFO) + << "Reducer: comm-optimized memory allocator not found, using regular one"; + bucket.gradients = at::empty({bucketSize}, options); + } // Note: "Gradient Layout Contract" // From c4d835fbab5cf173ee7a514315b63989aed44573 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Fri, 7 Feb 2025 11:39:42 -0800 Subject: [PATCH 28/28] [DTensor][conv] add DTensor convolution_backward op support for case where the input Tensor has requires_grad=False (#142278) Fixes #142058 ## Summary DTensor `convolution_backward` op throws exception when the input Tensor has `requires_grad=False` which happens if the conv layer is the first layer in the model. ATEN convolution_backward op Usually returns 3 Tensors (grad_input, grad_weight, grad_bias) and the `grad_input` is actually an Optional[Tensor] which can be `None` in the case mentioned above. 
However, the DTensor sharding propagation rule and corresponding TP conv backward implementation both assume that the `grad_input` would be existent. ## Fix allow the `grad_input` to be `None` for `convolution_backward` op. ## Test `pytest test/distributed/tensor/test_convolution_ops.py` ## Follow-up The current implementation of DTensor conv op also ignores `output_mask` and this may need further care. Pull Request resolved: https://github.com/pytorch/pytorch/pull/142278 Approved by: https://github.com/bdhirsh --- .../tensor/test_convolution_ops.py | 28 ++++++++++++- torch/distributed/tensor/_ops/_conv_ops.py | 4 ++ torch/distributed/tensor/_sharding_prop.py | 39 +++++++++++++++---- torch/distributed/tensor/_tp_conv.py | 11 +++--- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index 867e89e778b..5d40a18f067 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -5,13 +5,15 @@ import copy import torch import torch.nn as nn -from torch.distributed._tensor import ( - DeviceMesh, +from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed.tensor import ( distribute_module, distribute_tensor, + DTensor, Replicate, Shard, ) +from torch.nn import functional as F from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -181,6 +183,28 @@ class DistConvolutionOpsTest(DTensorTestBase): f"Too large relative mse for bias tensor, expected less equal 1e-6, got {bias_mse_rel}", ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_conv_backward_none_grad_inp(self): + device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(self.world_size,) + ) + conv = nn.Conv2d(64, 64, 3, padding=1).train() + x = torch.randn(1, 64, 32, 32) + x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) 
+ w = conv.weight + w_dt = torch.nn.Parameter(DTensor.from_local(w, device_mesh, [Replicate()])) + + b = conv.bias + b_dt = torch.nn.Parameter(DTensor.from_local(b, device_mesh, [Replicate()])) + + res = F.conv2d(x_dt, w_dt, b_dt, padding=1) + dres = torch.rand_like(res) + res.backward(dres) + self.assertTrue(w_dt.grad is not None) + self.assertTrue(b_dt.grad is not None) + self.assertTrue(x_dt.grad is None) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/tensor/_ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py index 9842ed1fde3..2198986d50c 100644 --- a/torch/distributed/tensor/_ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -105,4 +105,8 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: [0], tensor_meta=bias_tensor_meta, ) + # TODO: actually the output_mask is not respected here, we should + # set the corresponding spec to `None` if the output_mask is not `False` + # for a certain output Tensor. This also applies to the conv handler + # in torch/distributed/tensor/_tp_conv.py return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index cc5a80e2e82..e81957506d6 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -57,12 +57,14 @@ class ShardingPropagator: OpOverload, Callable[[DeviceMesh, OpSchema], StrategyType], ] = {} - # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop + # op map to save static argnum to decide to reuse sharding prop cache or + # re-run sharding prop self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {} self.propagate_op_sharding = LocalLRUCache( self.propagate_op_sharding_non_cached ) - # op map to save indices of shape (and stride) args which may need to be modified in sharding prop + # op map to save indices of shape (and stride) 
args which may need to be + # modified in sharding prop self.op_to_shape_and_stride_idx: dict[ OpOverload, Union[int, tuple[int, int]] ] = { @@ -171,10 +173,12 @@ class ShardingPropagator: # Either error due to ShardingPropagator or due to incorrect OutputSpec if not isinstance(output_tensor_meta, (tuple, list)): raise ValueError( - "ShardingPropagator error: output does not have an associated TensorMeta" + "ShardingPropagator error: output does not have an associated " + "TensorMeta" ) raise ValueError( - f"For the op {op.name()}, `output_specs` has 1 output which does not equal the " + f"For the op {op.name()}, `output_specs` has 1 output which does " + "not equal the " f"number of op outputs: {len(output_tensor_meta)}." ) output_specs.tensor_meta = output_tensor_meta @@ -183,16 +187,35 @@ class ShardingPropagator: output_specs ) != len(output_tensor_meta): raise ValueError( - f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " + f"For the op {op.name()}, `output_specs` has {len(output_specs)} " + "outputs which does not equal the " f"number of op outputs {_length(output_tensor_meta)}." ) + for i, spec in enumerate(output_specs): if isinstance(spec, DTensorSpec): output_tensor_meta_i = output_tensor_meta[i] if not isinstance(output_tensor_meta_i, TensorMeta): - raise ValueError( - f"ShardingPropagator error: output {i} does not have an associated TensorMeta" - ) + # NOTE: aten.convolution_backward.default is an exception and it + # needs extra handling because the first Tensor in the output + # tuple can be `None` if the input Tensor to convolution op has + # `requires_grad=False` (e.g. convolution layer is the first + # layer in the model). We explicitly allow its corresponding + # TensorMeta to be `None`. 
+ if ( + op == aten.convolution_backward.default + and i == 0 + and output_tensor_meta_i is None + ): + assert isinstance(output_specs, list) + output_specs[i] = None + continue + else: + raise ValueError( + f"ShardingPropagator error: output {i} of {op.name()} " + "does not have an associated TensorMeta" + ) + spec.tensor_meta = output_tensor_meta_i def propagate(self, op_info: OpInfo) -> None: diff --git a/torch/distributed/tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py index e9ae126e3c5..f3e908f3e7a 100644 --- a/torch/distributed/tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -215,12 +215,13 @@ def tp_convolution_backward( # step4 aggregate gradients for edge pixels grad_in_tensor = local_results[0] - grad_in_tensor = _ring_send_recv_aggregate( - grad_in_tensor, d1, d2, left, right, rank, size - ) + if grad_in_tensor is not None: + grad_in_tensor = _ring_send_recv_aggregate( + grad_in_tensor, d1, d2, left, right, rank, size + ) + local_results = list(local_results) + local_results[0] = grad_in_tensor - local_results = list(local_results) - local_results[0] = grad_in_tensor local_results = cast(tuple[object, ...], local_results) return local_results