[export] Refactor constrain_as_value and constrain_as_size (#106591)

Some notable changes:
1. `constrain_as_size` allows the min value to be less than 2, as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add an additional check to make sure the max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's value range, which would always be between [2, max]. I modified this logic to assert on [0, max] unless the user explicitly specifies the min range.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2023-08-14 11:10:02 -07:00 committed by PyTorch MergeBot
parent d6c120d7f9
commit 20c5add133
19 changed files with 389 additions and 145 deletions

View file

@ -1,3 +1,4 @@
#include <limits>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
@ -15,21 +16,65 @@
#include <ATen/ops/_make_dep_token_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sym_constrain_range_native.h>
#include <ATen/ops/sym_constrain_range_for_size_native.h>
#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
#endif
namespace at {
namespace native {
void sym_constrain_range_cpu(
void sym_constrain_range(
const Scalar& size,
c10::optional<int64_t> min,
c10::optional<int64_t> max) {}
c10::optional<int64_t> max) {
Tensor _functional_sym_constrain_range_cpu(
int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
int64_t size_as_int = size.toInt();
TORCH_CHECK(
max_val >= min_val,
"Max must be greater than or equal to min. Got min=",
min_val,
" max=",
max_val
);
TORCH_CHECK(
min_val <= size_as_int && size_as_int <= max_val,
"Invalid value range for ",
size_as_int,
" between [",
min_val,
", ",
max_val,
"]."
);
}
// Functional variant of sym_constrain_range: runs the same [min, max] range
// check on `size`, then returns a clone of `dep_token` so the op produces a
// value and carries a data dependency (keeps it alive through
// functionalization instead of being a pure side effect).
Tensor _functional_sym_constrain_range(
const Scalar& size,
c10::optional<int64_t> min,
c10::optional<int64_t> max,
const Tensor& dep_token) {
sym_constrain_range(size, min, max);
return dep_token.clone();
}
void sym_constrain_range_for_size(const Scalar& size, c10::optional<int64_t> min, c10::optional<int64_t> max) {
int64_t min_val = min.has_value() ? min.value() : 0;
if (max.has_value() && max.value() <= 2) {
TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
}
sym_constrain_range(size, min_val, max);
}
// Functional variant of sym_constrain_range_for_size: performs the same
// size-style range check, then clones `dep_token` to give the op a returned
// value and an explicit data dependency for functionalization.
Tensor _functional_sym_constrain_range_for_size(
const Scalar& size,
c10::optional<int64_t> min,
c10::optional<int64_t> max,
const Tensor& dep_token) {
sym_constrain_range_for_size(size, min, max);
return dep_token.clone();
}

View file

@ -1,15 +0,0 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <c10/core/Scalar.h>
#include <c10/util/Optional.h>
namespace at {
namespace native {
void sym_constrain_range_cuda(
const Scalar& size,
c10::optional<int64_t> min = c10::nullopt,
c10::optional<int64_t> max = c10::nullopt) {}
} // namespace native
} // namespace at

View file

@ -181,14 +181,21 @@
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
dispatch:
CPU: sym_constrain_range_cpu
CUDA: sym_constrain_range_cuda
CompositeExplicitAutograd: sym_constrain_range
- func: sym_constrain_range_for_size(Scalar size, *, int? min, int? max) -> ()
dispatch:
CompositeExplicitAutograd: sym_constrain_range_for_size
- func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
dispatch:
CPU: _functional_sym_constrain_range_cpu
CompositeExplicitAutograd: _functional_sym_constrain_range
- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
dispatch:
CompositeExplicitAutograd: _functional_sym_constrain_range_for_size
- func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:

View file

@ -2273,7 +2273,7 @@ def forward(self, x):
def test_export_preserve_constraints_as_metadata_scalar(self):
def f(x, y):
b = x.item()
constrain_as_size(b, min=2, max=5)
constrain_as_size(b)
return torch.empty((b, y.shape[0]))
x = torch.tensor([3])
@ -2322,7 +2322,7 @@ def forward(self, x):
def f(x, y):
b = x.item()
constrain_as_size(b, min=2, max=5)
constrain_as_size(b)
return torch.empty((b, y.shape[0]))
x = torch.tensor([3])
@ -2344,11 +2344,11 @@ def forward(self, x):
def test_export_with_inline_constraints(self):
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
constrain_as_value(a, 4, 7)
return torch.empty((a, 4))
with self.assertRaisesRegex(
torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
RuntimeError, r"Invalid value range for 20 between \[4, 7\]."
) as cm:
torch._export.export(f, (torch.tensor([20]),))
@ -2368,7 +2368,7 @@ def forward(self, x):
def test_export_with_inline_constraints_complex(self):
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
constrain_as_value(a, 4, 7)
empty = torch.empty((a, 4))
return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)

View file

@ -322,6 +322,7 @@ aten::_foreach_zero.out
aten::_foreach_zero_
aten::_functional_assert_async.msg
aten::_functional_sym_constrain_range
aten::_functional_sym_constrain_range_for_size
aten::_fused_adam
aten::_fused_adam.out
aten::_fused_adam_

View file

@ -5,8 +5,8 @@ import unittest
import torch
import torch._dynamo as torchdynamo
from torch._export import export, dynamic_dim, DEFAULT_EXPORT_DYNAMO_CONFIG
from torch._export.utils import register_dataclass_as_pytree_node
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.utils import register_dataclass_as_pytree_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec
@ -21,7 +21,7 @@ class TestDynamismExpression(TestCase):
def f(x):
b = x.item()
constrain_as_size(b, min=2, max=5)
constrain_as_size(b)
return torch.full((b, 1), 1)
inp = (torch.tensor([3]),)
@ -37,23 +37,6 @@ class TestDynamismExpression(TestCase):
self.assertTrue(torchdynamo.utils.same(ref, res))
def test_export_constraints_error(self):
def invalid_size(x):
b = x.item()
constrain_as_size(b, min=0, max=5)
return torch.full((b, 1), 1)
inp = (torch.tensor([3]),)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
export(invalid_size, inp)
def invalid_input_conflict_with_inline_constraints(x):
b = x.item()
constrain_as_size(b, min=2, max=5)
return torch.full((b, 1), 1)
inp = (torch.tensor([6]),)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
export(invalid_input_conflict_with_inline_constraints, inp)
def invalid_input_conflict_with_input_constraints(x):
return x + 1
@ -69,16 +52,15 @@ class TestDynamismExpression(TestCase):
constraints=inp_constraints,
)
def conflicting_constraints(x):
b = x.item()
constrain_as_size(b, min=2, max=3)
constrain_as_size(b, min=4, max=5)
constrain_as_size(b)
constrain_as_value(b, min=4, max=5)
return torch.full((b, 1), 1)
inp = (torch.tensor([3]),)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"):
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 3 between \[4, 5\]"):
export(conflicting_constraints, inp)
def test_export_assume_static_by_default(self):
@ -222,25 +204,6 @@ class TestExport(TestCase):
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
em(x)
def test_export_constrain_static(self):
def f(x, y):
b = x.item()
constrain_as_size(b, min=2, max=5)
c = y.dim()
constrain_as_value(c, min=1, max=3)
z = y[0:c]
return torch.empty((b, y.shape[0])), z
x = torch.tensor([3])
y = torch.randn([8, 8, 6])
example_inputs = (x, y)
constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]
with self.assertRaisesRegex(
torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
"on a value which we evaluated to have a static value of 3. "
):
export(f, example_inputs, {}, constraints)
def test_not_correct_dim(self):
def f(x):
return x.cos()
@ -588,5 +551,161 @@ class TestExport(TestCase):
# Intentionally not wrapping `inp` in a tuple to trigger the error
_ = export(fn, inp)
def test_constrain_value_with_no_default(self):
def fn(x, y):
n = x.max().item()
constrain_as_value(n)
return y + n
ep = export(fn, (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3))))
test_inp = (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3)))
self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
def test_constrain_value_with_symfloat(self):
def fn(x, y):
n = x.max().item()
constrain_as_value(n)
return y + n
with self.assertRaisesRegex(torch._dynamo.exc.TorchRuntimeError, "Constraining SymFloat or Symbool is nyi"):
_ = export(fn, (torch.rand(2, 2), torch.rand(2, 3)))
def test_constrain_size_in_eager(self):
def fn(x, y):
n = x.max().item()
constrain_as_size(n)
return y + n
ep = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
def test_constrain_size_with_constrain_value(self):
def fn(x, y):
n = x.max().item()
constrain_as_value(n, 2, 10)
constrain_as_size(n)
return y + n
# Since we are using constrain_as_value, we expect to raise error when user
# passes in invalid tracing input
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
_ = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
_ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
ep = export(fn, (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3))))
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint"):
test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
_ = ep(*test_inp)
def test_constrain_size_with_various_cases(self):
def case_1(x, y):
n = x.item()
constrain_as_size(n, min=0)
return y.sum() + torch.ones(n, 5).sum()
def case_2(x, y):
n = x.item()
constrain_as_size(n, min=0, max=6)
return y.sum() + torch.ones(n, 5).sum()
def case_3(x, y):
n = x.item()
constrain_as_size(n, min=0, max=1)
return y.sum() + torch.ones(n, 5).sum()
def case_4(x, y):
n = x.item()
constrain_as_size(n, min=2)
return y.sum() + torch.ones(n, 5).sum()
def case_5(x, y):
n = x.item()
constrain_as_size(n, min=1)
return y.sum() + torch.ones(n, 5).sum()
ep = export(case_1, (torch.tensor(1), torch.ones(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, inf\]."):
_ = ep(torch.tensor(-1), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for -1 between"):
_ = case_1(torch.tensor(-1), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(1), torch.ones(4, 5)),
case_1(torch.tensor(1), torch.ones(4, 5)),
)
)
ep = export(case_2, (torch.tensor(5), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, 6\]."):
_ = ep(torch.tensor(7), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 7 between"):
_ = case_2(torch.tensor(7), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between \[0, 6\]."):
_ = export(case_2, (torch.tensor(9), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between"):
_ = case_2(torch.tensor(9), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(5), torch.ones(4, 5)),
case_2(torch.tensor(5), torch.ones(4, 5)),
)
)
with self.assertRaisesRegex(
torch._dynamo.exc.TorchRuntimeError,
"Maximum value to constrain_as_size must be greater than 2, but was 1"
):
_ = export(case_3, (torch.tensor(1), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, "Max value to constrain_range_for_size must be greater than 2. got: 1"):
_ = case_3(torch.tensor(1), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
_ = export(case_4, (torch.tensor(1), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
_ = case_4(torch.tensor(1), torch.randn(4, 5))
ep = export(case_4, (torch.tensor(5), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[2, inf\]."):
_ = ep(torch.tensor(1), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1"):
_ = case_4(torch.tensor(1), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0 between \[1, 9223372036854775807\]."):
_ = export(case_5, (torch.tensor(0), torch.randn(4, 5)))
self.assertTrue(
torch.allclose(
ep(torch.tensor(5), torch.ones(4, 5)),
case_4(torch.tensor(5), torch.ones(4, 5)),
)
)
ep = export(case_5, (torch.tensor(5), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[1, inf\]."):
_ = ep(torch.tensor(0), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0"):
_ = case_5(torch.tensor(0), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(5), torch.ones(4, 5)),
case_5(torch.tensor(5), torch.ones(4, 5)),
)
)
if __name__ == '__main__':
run_tests()

View file

@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing import FileCheck
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import export, dynamic_dim
from torch._export.constraints import constrain_as_value, constrain_as_size
from torch._export.constraints import constrain_as_value
from torch._export.passes import (
ReplaceViewOpsWithViewCopyOpsPass,
)
@ -346,7 +346,7 @@ class TestPasses(TestCase):
def test_functionalize_inline_contraints(self) -> None:
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
constrain_as_value(a, 4, 7)
return torch.empty((a, 4))
ep = torch._export.export(f, (torch.tensor([7]),))

View file

@ -314,6 +314,7 @@ ALLOW_LIST = [
("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)),
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
]
ALLOW_LIST_COMPILED = [

View file

@ -10,9 +10,10 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
from torch._decomp import decomposition_table
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch.fx.experimental.symbolic_shapes import (
sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
constrain_range, guard_int, GuardOnDataDependentSymNode
guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
@ -1041,7 +1042,7 @@ def forward(self, a_1):
def test_item_to_constructor(self):
def f(a):
r = a.item()
constrain_range(r, min=2)
constrain_as_size(r)
return torch.empty(r)
r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
@ -1049,6 +1050,7 @@ def forward(self, a_1):
r, """\
def forward(self, a_1):
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None
sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None)
empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
return empty""" # noqa: B950
)
@ -1127,7 +1129,7 @@ def forward(self, crop_camera_1, mask_1):
for s in p.shape:
guard_int(s)
x = x[mask]
constrain_range(x.shape[0], min=1)
constrain_as_value(x.shape[0], min=1)
for p in params.values():
p.grad = None
return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

View file

@ -145,6 +145,11 @@ FILENAME_ALLOWLIST |= {
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
}
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "_export/constraints.py",
}
# TODO (zhxchen17) Make exportdb importable here.
FILENAME_ALLOWLIST |= set(
glob.glob(_module_dir(torch) + "_export/db/examples/*.py"),
) | {

View file

@ -363,7 +363,7 @@ def export(
# so we serialize them here instead of inside dynamo
gm.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
for k, v in fake_mode.shape_env.runtime_var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}

View file

@ -1,48 +1,60 @@
from typing import Optional, Callable, Union
from typing import Optional
import torch
from torch import SymInt, SymFloat
from torch._dynamo import allow_in_graph
from torch.fx.experimental.symbolic_shapes import constrain_range_int
from torch.utils._sympy.value_ranges import ValueRangeError
# The `Scalar` type used in native_functions.yaml will be translated to `Union[Number, _complex]`,
# which could cause a type error since `SymInt` or `SymFloat` will be used.
# Here we manually specify the type explicitly.
sym_constrain_range: Callable[
[Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]],
None,
] = torch.sym_constrain_range # type: ignore[assignment]
# TODO: we want to hide this min/max stuff under some abstraction similar to
# DynamicDim
@allow_in_graph
def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
"""
Add min/max constraint on the intermediate symbol at tracing time
Add min/max constraint on the intermediate symbol at tracing time. If called in eager mode,
it will still check if the input value is within the specified range.
"""
if not isinstance(symbol, SymInt):
constrain_range_int(symbol, min=min, max=max)
else:
sym_constrain_range(symbol, min, max)
return symbol
torch.sym_constrain_range(symbol, min=min, max=max)
# TODO: we want to hide this min/max stuff under some abstraction similar to
# DynamicDim
@allow_in_graph
def constrain_as_size(symbol, min: int = 2, max: Optional[int] = None):
"""
Add min/max constraint on the intermediate symbol which will be used as a size
def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None):
"""
This indicates that a given int is size-like, and can be used in any context where a size is expected.
You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
# TODO: we should investigate turning off 0/1 specialization for unbacked
# SymInts
if min < 2:
raise ValueRangeError(
"Unable to set min size to be <= 2 because we specialize on 0/1 sizes."
)
return constrain_as_value(symbol, min, max)
This function has unusual semantics which distinguish it from constrain_as_value.
Specifically, at compile-time, we will unsoundly assume that the resulting int is always >= 2.
As a result, max value you pass in should always be greater than 2.
This makes it easier to use the unbacked int in size contexts, as we will often attempt to guard on a size being zero/one
(e.g., when computing the contiguity of a tensor, or testing if broadcasting can occur),
which will not work on unbacked SymInts. Assuming that the int is >= 2 allows us to
report False to these tests. Although this is technically unsound,
in practice we observe that if your program works for all sizes >= 2,
it probably works for zero and one too. The reason we specifically assume size is >= 2 is because a
lot of PyTorch code is specialized for 0 and 1, which could result in graphs that are not general.
At runtime, we only assert that the user provided min/max values are respected.
To demonstrate in a scenario, suppose you do
```
# Case 1
# This will assume symbol is between [2, inf) at compile time, but [0, inf) at runtime
constrain_as_size(symbol, min=0)
# Case 2
# This will assume symbol is between [2, N] at compile time, but [0, N] at runtime
constrain_as_size(symbol, min=0, max=N)
# Case 3
# This is not valid case as max is <= 2
constrain_as_size(symbol, min=0, max=1)
# Case 4
# This will assume symbol is between [2, inf) at compile time, AND [2, inf) at runtime
constrain_as_size(symbol, min=2)
# Case 5
# This will assume symbol is between [2, inf) at compile time, but [1, inf) at runtime
constrain_as_size(symbol, min=1)
```
"""
torch.sym_constrain_range_for_size(symbol, min=min, max=max)

View file

@ -806,7 +806,12 @@ class GraphModuleDeserializer:
if vr := self.symbol_name_to_range.get(val.expr_str):
symbolic_shapes._constrain_symbol_range(
self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type]
self.shape_env,
sym,
compiler_min=vr.lower, # type: ignore[arg-type]
compiler_max=vr.upper, # type: ignore[arg-type]
runtime_min=vr.lower, # type: ignore[arg-type]
runtime_max=vr.upper # type: ignore[arg-type]
)
return self.shape_env.create_symintnode(sym, hint=val.hint)

View file

@ -83,6 +83,11 @@ def functional_assert_async_msg_decomp(tensor, msg):
return
@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(symbol, *, min=None, max=None):
    # No-op decomposition: the range constraint only informs compile-time
    # (meta/symbolic) reasoning, so there is nothing to compute here.
    return
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):

View file

@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch._prims_common as utils
from torch import Tensor
from torch import SymBool, SymFloat, Tensor
from torch._decomp import (
_add_op_to_registry,
_convert_out_params,
@ -30,7 +30,10 @@ from torch._prims_common.wrappers import (
out_wrapper,
)
from torch._refs import _broadcast_shapes, _maybe_broadcast
from torch.fx.experimental.symbolic_shapes import constrain_range
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
constrain_range,
)
from torch.utils._pytree import tree_map
@ -424,6 +427,8 @@ def make_dep_token(
@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
    # Only int-like values can be constrained; SymFloat/SymBool support is
    # not yet implemented.
    if isinstance(size, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat or Symbool is nyi")
    constrain_range(size, min=min, max=max)
@ -433,6 +438,19 @@ def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
return dep_token
@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
    # Size-like variant: delegates to the size-aware helper, which assumes
    # the value is >= 2 for compile-time purposes. SymFloat/SymBool are nyi.
    if isinstance(size, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat or Symbool is nyi")
    _constrain_range_for_size(size, min=min, max=max)
@register_meta(aten._functional_sym_constrain_range_for_size.default)
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
    # Functionalized variant: apply the constraint, then thread dep_token
    # through unchanged to preserve the data dependency.
    aten.sym_constrain_range_for_size(size, min=min, max=max)
    return dep_token
@register_meta(aten._functional_assert_async.msg)
def functional_assert_async_meta(val, assert_msg, dep_token):
    # Meta kernel: the assert has no effect at the meta level; just pass the
    # dependency token through.
    return dep_token

View file

@ -32,7 +32,7 @@ from torch import ( # noqa: F401
SymFloat,
SymInt,
)
from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
@ -304,13 +304,57 @@ def guard_scalar(a):
else:
raise AssertionError(f"unrecognized scalar {a}")
def _constrain_symbol_range(shape_env, s: sympy.Symbol, min: int, max: int):
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int, runtime_min: int, runtime_max: int):
    """Record (or narrow) the compile-time and runtime value ranges for
    symbol ``s`` on ``shape_env``.

    Each range is intersected with any previously recorded range so repeated
    constraints only ever tighten the bounds.
    """
    def _narrow(ranges, lo, hi):
        # Intersect with an existing range if one has been recorded.
        existing = ranges.get(s, None)
        if existing:
            ranges[s] = ValueRanges(
                builtins.max(existing.lower, lo), builtins.min(existing.upper, hi)
            )
        else:
            ranges[s] = ValueRanges(lo, hi)

    _narrow(shape_env.var_to_range, compiler_min, compiler_max)
    _narrow(shape_env.runtime_var_to_range, runtime_min, runtime_max)
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
    """
    This function is NOT INTENDED to be used by itself.

    Constrain an unbacked SymInt that will be used as a size. The runtime
    range is [min, max] (min defaulting to 0), but the compile-time range is
    clamped to a lower bound of 2 so the compiler can assume the size avoids
    0/1 specialization.

    Raises ValueError for SymFloat/SymBool inputs, for max <= 2 (which would
    contradict the compile-time assumption), and for max < min.
    """
    if isinstance(a, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat/SymBool is nyi")

    assert isinstance(a, SymInt)
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

    if min is None:
        min = 0
    if max is None:
        max = sympy.oo

    if max <= 2:
        raise ValueError(f"Maximum value to constrain_as_size must be greater than 2, but was {max}")

    if max < min:
        # BUGFIX: this message was missing its f-string prefix, so "{min}" and
        # "{max}" were printed literally instead of the offending values.
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    # Unconditionally assume >= 2 for the compiler, per the size-like contract.
    compiler_min = 2 if min < 2 else min

    _constrain_symbol_range(
        a.node.shape_env,
        a.node.expr,
        compiler_min=compiler_min,
        compiler_max=max,
        runtime_min=min,
        runtime_max=max
    )
# inclusive both ways
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
@ -350,8 +394,16 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
min = -sympy.oo
if max is None:
max = sympy.oo
if not isinstance(a, SymInt):
constrain_range_int(a, min=min, max=max)
if max < min:
raise ValueError(
"Maximum value to constrain_as_size can't be less than the specified min value, "
"received min={min} and max={max}"
)
if isinstance(a, int):
if not (min <= a <= max):
raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
return
if isinstance(a.node.expr, sympy.Integer):
@ -364,35 +416,15 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
# semantics that this is an "unchecked" assert (but it this actually
# something useful? Might be better to restrict only for unbacked
# SymInt).
_constrain_symbol_range(a.node.shape_env, a.node.expr, min, max)
_constrain_symbol_range(
a.node.shape_env,
a.node.expr,
compiler_min=min,
compiler_max=max,
runtime_min=min,
runtime_max=max
)
def constrain_range_int(a, *, min, max):
"""
Constrain range on concrete int value.
This can happens for the following scenarios:
- Eager mode execution and real int value is provided.
- During tracing the traced symbol is resolved as a static integer (see
PR #101655 for more details).
"""
if min is None:
min = -sympy.oo
if max is None:
max = sympy.oo
assert not isinstance(a, SymInt)
if not (min <= a <= max):
raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
if (
(fake_mode := detect_fake_mode()) is not None and
getattr(fake_mode, "shape_env", None) is not None
):
# If we are tracing with a fake mode then add this integer to the
# shape_env's var_to_range
sym_integer = sympy.Integer(a)
shape_env = fake_mode.shape_env
_constrain_symbol_range(shape_env, sym_integer, min, max)
shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
def constrain_unify(a, b):
"""
@ -1938,6 +1970,10 @@ class ShapeEnv:
# range may contain ints which may not actually appear in
# practice
self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
# Maps symbolic ints to their min/max range for runtime checks.
# This is because we assume a graph generated with N=2 is general enough
# for N < 2. Therefore, it will be too strict to assert N=2 at runtime.
self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {}
# Maps symbolic ints to the guards that refine their lower/upper

View file

@ -36,6 +36,7 @@ _side_effectful_functions: Set[Callable] = {
_ops.aten._assert_async.msg,
_ops.aten.copy_.default,
_ops.aten.sym_constrain_range.default,
_ops.aten.sym_constrain_range_for_size.default,
_ops.profiler._record_function_enter,
_ops.profiler._record_function_enter_new,
_ops.profiler._record_function_exit}

View file

@ -186,6 +186,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.sym_min,
torch.sym_not,
torch.sym_constrain_range,
torch.sym_constrain_range_for_size,
torch.tril_indices,
torch.triu_indices,
torch.vander,

View file

@ -74,6 +74,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"record_stream", # no return
"sparse_dim", # returns an int
"sym_constrain_range", # no return
"sym_constrain_range_for_size", # no return
"_nested_tensor_storage_offsets", # returns a vector of ints
"_chunk_grad_outputs_efficient_attention", # returns a bool
"_fused_sdp_choice", # returns an int