mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Introduce sym_min and sym_max (#92107)
It turns out our old max/min implementation didn't do anything, because `__max__` and `__min__` are not actually magic methods in Python. So I give 'em the `sym_` treatment, similar to the other non-overrideable builtins. NB: I would like to use `sym_max` when computing contiguous strides but this appears to make `python test/functorch/test_aotdispatch.py -v -k test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32` run extremely slowly. Needs investigating. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/92107 Approved by: https://github.com/albanD, https://github.com/voznesenskym, https://github.com/Skylion007
This commit is contained in:
parent
b26efd0dd2
commit
6420fecdc4
11 changed files with 63 additions and 35 deletions
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
2ffe4a04df9d498e250153d931cadf9d92268510
|
||||
3ab8494305810d3c943f670bc6b028514942c7a0
|
||||
|
|
|
|||
|
|
@ -143,14 +143,14 @@ SymInt SymInt::min(const SymInt& sci) const {
|
|||
return std::min(data_, sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt(res[0]->min(res[1]));
|
||||
return SymInt(res[0]->sym_min(res[1]));
|
||||
}
|
||||
SymInt SymInt::max(const SymInt& sci) const {
|
||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
||||
return std::max(data_, sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt(res[0]->max(res[1]));
|
||||
return SymInt(res[0]->sym_max(res[1]));
|
||||
}
|
||||
|
||||
void SymInt::operator*=(const SymInt& sci) {
|
||||
|
|
|
|||
|
|
@ -76,10 +76,10 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
|||
virtual SymNode neg() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode min(const SymNode& other) {
|
||||
virtual SymNode sym_min(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode max(const SymNode& other) {
|
||||
virtual SymNode sym_max(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode clone() {
|
||||
|
|
|
|||
|
|
@ -618,8 +618,17 @@ Utilities
|
|||
is_warn_always_enabled
|
||||
vmap
|
||||
_assert
|
||||
|
||||
Symbolic Numbers
|
||||
----------------
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
sym_float
|
||||
sym_int
|
||||
sym_max
|
||||
sym_min
|
||||
|
||||
Optimizations
|
||||
-------------
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import itertools
|
|||
import random
|
||||
import contextlib
|
||||
import math
|
||||
import builtins
|
||||
import atexit
|
||||
import io
|
||||
import os
|
||||
|
|
@ -508,11 +507,7 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
# These functions might return plain int/float
|
||||
has_valid_downcast = fn in ["min", "max"]
|
||||
if fn in symbolic_shapes.magic_methods_on_builtins:
|
||||
lambda_apply = getattr(builtins, fn)
|
||||
elif fn in symbolic_shapes.magic_methods_on_math:
|
||||
if fn in symbolic_shapes.magic_methods_on_math:
|
||||
lambda_apply = getattr(math, fn)
|
||||
elif fn in symbolic_shapes.magic_methods_on_submodule:
|
||||
lambda_apply = getattr(symbolic_shapes, fn)
|
||||
|
|
@ -529,16 +524,10 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
tp = "float" if any(isinstance(i, float) for i in [inp1, inp2]) else "int"
|
||||
|
||||
def guard_fn(v):
|
||||
try:
|
||||
if fn in symbolic_shapes.always_bool_magic_methods:
|
||||
return bool(v)
|
||||
else:
|
||||
return getattr(v.node, f"guard_{tp}")("", 0)
|
||||
except Exception as e:
|
||||
if has_valid_downcast:
|
||||
return v
|
||||
else:
|
||||
raise e
|
||||
if fn in symbolic_shapes.always_bool_magic_methods:
|
||||
return bool(v)
|
||||
else:
|
||||
return getattr(v.node, f"guard_{tp}")("", 0)
|
||||
|
||||
# Get reference result
|
||||
with maybe_xfail(inp1, inp2):
|
||||
|
|
|
|||
2
third_party/ideep
vendored
2
third_party/ideep
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit e7925bc7c260e6c4481ccb53b7d29c59a901a05d
|
||||
Subproject commit 7201315611bebbb041f2ca7a0cdb3c6f4ccd17a3
|
||||
|
|
@ -50,7 +50,8 @@ __all__ = [
|
|||
'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
|
||||
'set_float32_matmul_precision', 'get_float32_matmul_precision',
|
||||
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
||||
'sym_int', 'sym_float', 'compile', 'vmap']
|
||||
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap'
|
||||
]
|
||||
|
||||
################################################################################
|
||||
# Load the extension module
|
||||
|
|
@ -259,6 +260,12 @@ class SymInt:
|
|||
def __ge__(self, other) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_max__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_min__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_float__(self):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
|
|
@ -303,6 +310,12 @@ class SymFloat:
|
|||
def __ge__(self, other) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_max__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_min__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __repr__(self):
|
||||
return self.node.str()
|
||||
|
||||
|
|
@ -343,6 +356,22 @@ def sym_int(a):
|
|||
return _sym_floor(a) if a > 0 else _sym_ceil(a)
|
||||
return py_int(a) # type: ignore[operator]
|
||||
|
||||
def sym_max(a, b):
|
||||
""" SymInt-aware utility for max()."""
|
||||
if isinstance(a, (SymInt, SymFloat)):
|
||||
return a.__sym_max__(b)
|
||||
elif isinstance(b, (SymInt, SymFloat)):
|
||||
return b.__sym_max__(a)
|
||||
return builtins.max(a, b) # type: ignore[operator]
|
||||
|
||||
def sym_min(a, b):
|
||||
""" SymInt-aware utility for max()."""
|
||||
if isinstance(a, (SymInt, SymFloat)):
|
||||
return a.__sym_min__(b)
|
||||
elif isinstance(b, (SymInt, SymFloat)):
|
||||
return b.__sym_min__(a)
|
||||
return builtins.min(a, b) # type: ignore[operator]
|
||||
|
||||
# Check to see if we can load C extensions, and if not provide some guidance
|
||||
# on what the problem might be.
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1174,8 +1174,8 @@ void initJITBindings(PyObject* module) {
|
|||
SYMNODE_BINARY(lt)
|
||||
SYMNODE_BINARY(le)
|
||||
SYMNODE_BINARY(ge)
|
||||
SYMNODE_BINARY(min)
|
||||
SYMNODE_BINARY(max)
|
||||
SYMNODE_BINARY(sym_min)
|
||||
SYMNODE_BINARY(sym_max)
|
||||
SYMNODE_UNARY(ceil)
|
||||
SYMNODE_UNARY(floor)
|
||||
SYMNODE_UNARY(neg)
|
||||
|
|
|
|||
|
|
@ -141,10 +141,10 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
|
|||
return dispatch_common_(__func__, other);
|
||||
}
|
||||
|
||||
c10::SymNode min(const c10::SymNode& other) override {
|
||||
c10::SymNode sym_min(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__func__, other);
|
||||
}
|
||||
c10::SymNode max(const c10::SymNode& other) override {
|
||||
c10::SymNode sym_max(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__func__, other);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import Set, Dict, List, Type, Optional, cast
|
|||
import sys
|
||||
import itertools
|
||||
import operator
|
||||
import builtins
|
||||
import math
|
||||
import functools
|
||||
import threading
|
||||
|
|
@ -15,7 +14,7 @@ import textwrap
|
|||
import logging
|
||||
|
||||
# NB: The sym_* functions are used via getattr() and must be imported here.
|
||||
from torch import SymInt, SymFloat, sym_float, sym_int # noqa: F401
|
||||
from torch import SymInt, SymFloat, sym_float, sym_int, sym_max, sym_min # noqa: F401
|
||||
from torch._guards import ShapeGuard, Source
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -265,8 +264,8 @@ magic_methods = {
|
|||
'sym_float': lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals
|
||||
'ceil': lambda a: sympy.ceiling(a),
|
||||
'neg': lambda a: -a,
|
||||
'min': lambda a, b: sympy.Min(a, b),
|
||||
'max': lambda a, b: sympy.Max(a, b),
|
||||
'sym_min': lambda a, b: sympy.Min(a, b),
|
||||
'sym_max': lambda a, b: sympy.Max(a, b),
|
||||
'sym_sqrt': lambda a: sympy.sqrt(a),
|
||||
}
|
||||
|
||||
|
|
@ -278,9 +277,8 @@ unary_magic_methods = {
|
|||
'sym_sqrt',
|
||||
}
|
||||
|
||||
magic_methods_on_builtins = {"min", "max"}
|
||||
magic_methods_on_math = {"ceil", "floor"}
|
||||
magic_methods_on_submodule = {"sym_float", "sym_sqrt"}
|
||||
magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max"}
|
||||
|
||||
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt"}
|
||||
always_int_magic_methods = {"ceil", "floor"}
|
||||
|
|
@ -301,9 +299,10 @@ def _make_node_magic(method, func):
|
|||
func = lru_cache(256)(func)
|
||||
|
||||
def binary_magic_impl(self, other):
|
||||
if method in magic_methods_on_builtins:
|
||||
op = getattr(builtins, method)
|
||||
if method in magic_methods_on_submodule:
|
||||
op = getattr(sys.modules[__name__], method)
|
||||
else:
|
||||
assert method not in magic_methods_on_math
|
||||
op = getattr(operator, method)
|
||||
if SYM_FUNCTION_MODE:
|
||||
r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
|
||||
|
|
|
|||
|
|
@ -181,6 +181,8 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.sparse_bsc_tensor,
|
||||
torch.sym_float,
|
||||
torch.sym_int,
|
||||
torch.sym_max,
|
||||
torch.sym_min,
|
||||
torch.tril_indices,
|
||||
torch.triu_indices,
|
||||
torch.vander,
|
||||
|
|
|
|||
Loading…
Reference in a new issue