mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[export] Refactor constrain_as_value and constrain_as_size (#106591)
Some notable changes: 1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2. 2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591 Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
This commit is contained in:
parent
d6c120d7f9
commit
20c5add133
19 changed files with 389 additions and 145 deletions
|
|
@ -1,3 +1,4 @@
|
|||
#include <limits>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <c10/core/Device.h>
|
||||
|
|
@ -15,21 +16,65 @@
|
|||
#include <ATen/ops/_make_dep_token_native.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/sym_constrain_range_native.h>
|
||||
#include <ATen/ops/sym_constrain_range_for_size_native.h>
|
||||
#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
void sym_constrain_range_cpu(
|
||||
void sym_constrain_range(
|
||||
const Scalar& size,
|
||||
c10::optional<int64_t> min,
|
||||
c10::optional<int64_t> max) {}
|
||||
c10::optional<int64_t> max) {
|
||||
|
||||
Tensor _functional_sym_constrain_range_cpu(
|
||||
int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
|
||||
int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
|
||||
int64_t size_as_int = size.toInt();
|
||||
|
||||
TORCH_CHECK(
|
||||
max_val >= min_val,
|
||||
"Max must be greater than or equal to min. Got min=",
|
||||
min_val,
|
||||
" max=",
|
||||
max_val
|
||||
);
|
||||
|
||||
TORCH_CHECK(
|
||||
min_val <= size_as_int && size_as_int <= max_val,
|
||||
"Invalid value range for ",
|
||||
size_as_int,
|
||||
" between [",
|
||||
min_val,
|
||||
", ",
|
||||
max_val,
|
||||
"]."
|
||||
);
|
||||
}
|
||||
|
||||
Tensor _functional_sym_constrain_range(
|
||||
const Scalar& size,
|
||||
c10::optional<int64_t> min,
|
||||
c10::optional<int64_t> max,
|
||||
const Tensor& dep_token) {
|
||||
sym_constrain_range(size, min, max);
|
||||
return dep_token.clone();
|
||||
}
|
||||
|
||||
void sym_constrain_range_for_size(const Scalar& size, c10::optional<int64_t> min, c10::optional<int64_t> max) {
|
||||
int64_t min_val = min.has_value() ? min.value() : 0;
|
||||
if (max.has_value() && max.value() <= 2) {
|
||||
TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
|
||||
}
|
||||
sym_constrain_range(size, min_val, max);
|
||||
}
|
||||
|
||||
Tensor _functional_sym_constrain_range_for_size(
|
||||
const Scalar& size,
|
||||
c10::optional<int64_t> min,
|
||||
c10::optional<int64_t> max,
|
||||
const Tensor& dep_token) {
|
||||
sym_constrain_range_for_size(size, min, max);
|
||||
return dep_token.clone();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
void sym_constrain_range_cuda(
|
||||
const Scalar& size,
|
||||
c10::optional<int64_t> min = c10::nullopt,
|
||||
c10::optional<int64_t> max = c10::nullopt) {}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
@ -181,14 +181,21 @@
|
|||
|
||||
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
|
||||
|
||||
- func: sym_constrain_range(Scalar size, int? min=None, int? max=None) -> ()
|
||||
- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
|
||||
dispatch:
|
||||
CPU: sym_constrain_range_cpu
|
||||
CUDA: sym_constrain_range_cuda
|
||||
CompositeExplicitAutograd: sym_constrain_range
|
||||
|
||||
- func: sym_constrain_range_for_size(Scalar size, *, int? min, int? max) -> ()
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: sym_constrain_range_for_size
|
||||
|
||||
- func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
|
||||
dispatch:
|
||||
CPU: _functional_sym_constrain_range_cpu
|
||||
CompositeExplicitAutograd: _functional_sym_constrain_range
|
||||
|
||||
- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _functional_sym_constrain_range_for_size
|
||||
|
||||
- func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -2273,7 +2273,7 @@ def forward(self, x):
|
|||
def test_export_preserve_constraints_as_metadata_scalar(self):
|
||||
def f(x, y):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=2, max=5)
|
||||
constrain_as_size(b)
|
||||
return torch.empty((b, y.shape[0]))
|
||||
|
||||
x = torch.tensor([3])
|
||||
|
|
@ -2322,7 +2322,7 @@ def forward(self, x):
|
|||
|
||||
def f(x, y):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=2, max=5)
|
||||
constrain_as_size(b)
|
||||
return torch.empty((b, y.shape[0]))
|
||||
|
||||
x = torch.tensor([3])
|
||||
|
|
@ -2344,11 +2344,11 @@ def forward(self, x):
|
|||
def test_export_with_inline_constraints(self):
|
||||
def f(x):
|
||||
a = x.item()
|
||||
constrain_as_size(a, 4, 7)
|
||||
constrain_as_value(a, 4, 7)
|
||||
return torch.empty((a, 4))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError, r"Invalid value 20 for range \[4:7\]"
|
||||
RuntimeError, r"Invalid value range for 20 between \[4, 7\]."
|
||||
) as cm:
|
||||
torch._export.export(f, (torch.tensor([20]),))
|
||||
|
||||
|
|
@ -2368,7 +2368,7 @@ def forward(self, x):
|
|||
def test_export_with_inline_constraints_complex(self):
|
||||
def f(x):
|
||||
a = x.item()
|
||||
constrain_as_size(a, 4, 7)
|
||||
constrain_as_value(a, 4, 7)
|
||||
empty = torch.empty((a, 4))
|
||||
|
||||
return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
|
||||
|
|
|
|||
|
|
@ -322,6 +322,7 @@ aten::_foreach_zero.out
|
|||
aten::_foreach_zero_
|
||||
aten::_functional_assert_async.msg
|
||||
aten::_functional_sym_constrain_range
|
||||
aten::_functional_sym_constrain_range_for_size
|
||||
aten::_fused_adam
|
||||
aten::_fused_adam.out
|
||||
aten::_fused_adam_
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import unittest
|
|||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch._export import export, dynamic_dim, DEFAULT_EXPORT_DYNAMO_CONFIG
|
||||
from torch._export.utils import register_dataclass_as_pytree_node
|
||||
from torch._export.constraints import constrain_as_size, constrain_as_value
|
||||
from torch._export.utils import register_dataclass_as_pytree_node
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec
|
||||
|
|
@ -21,7 +21,7 @@ class TestDynamismExpression(TestCase):
|
|||
|
||||
def f(x):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=2, max=5)
|
||||
constrain_as_size(b)
|
||||
return torch.full((b, 1), 1)
|
||||
|
||||
inp = (torch.tensor([3]),)
|
||||
|
|
@ -37,23 +37,6 @@ class TestDynamismExpression(TestCase):
|
|||
self.assertTrue(torchdynamo.utils.same(ref, res))
|
||||
|
||||
def test_export_constraints_error(self):
|
||||
def invalid_size(x):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=0, max=5)
|
||||
return torch.full((b, 1), 1)
|
||||
|
||||
inp = (torch.tensor([3]),)
|
||||
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Unable to set min size"):
|
||||
export(invalid_size, inp)
|
||||
|
||||
def invalid_input_conflict_with_inline_constraints(x):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=2, max=5)
|
||||
return torch.full((b, 1), 1)
|
||||
|
||||
inp = (torch.tensor([6]),)
|
||||
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid value 6 for range"):
|
||||
export(invalid_input_conflict_with_inline_constraints, inp)
|
||||
|
||||
def invalid_input_conflict_with_input_constraints(x):
|
||||
return x + 1
|
||||
|
|
@ -69,16 +52,15 @@ class TestDynamismExpression(TestCase):
|
|||
constraints=inp_constraints,
|
||||
)
|
||||
|
||||
|
||||
def conflicting_constraints(x):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=2, max=3)
|
||||
constrain_as_size(b, min=4, max=5)
|
||||
constrain_as_size(b)
|
||||
constrain_as_value(b, min=4, max=5)
|
||||
return torch.full((b, 1), 1)
|
||||
|
||||
inp = (torch.tensor([3]),)
|
||||
|
||||
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Invalid ranges"):
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 3 between \[4, 5\]"):
|
||||
export(conflicting_constraints, inp)
|
||||
|
||||
def test_export_assume_static_by_default(self):
|
||||
|
|
@ -222,25 +204,6 @@ class TestExport(TestCase):
|
|||
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
|
||||
em(x)
|
||||
|
||||
def test_export_constrain_static(self):
|
||||
def f(x, y):
|
||||
b = x.item()
|
||||
constrain_as_size(b, min=2, max=5)
|
||||
c = y.dim()
|
||||
constrain_as_value(c, min=1, max=3)
|
||||
z = y[0:c]
|
||||
return torch.empty((b, y.shape[0])), z
|
||||
|
||||
x = torch.tensor([3])
|
||||
y = torch.randn([8, 8, 6])
|
||||
example_inputs = (x, y)
|
||||
constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]
|
||||
with self.assertRaisesRegex(
|
||||
torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
|
||||
"on a value which we evaluated to have a static value of 3. "
|
||||
):
|
||||
export(f, example_inputs, {}, constraints)
|
||||
|
||||
def test_not_correct_dim(self):
|
||||
def f(x):
|
||||
return x.cos()
|
||||
|
|
@ -588,5 +551,161 @@ class TestExport(TestCase):
|
|||
# Intentionally not wrapping `inp` in a tuple to trigger the error
|
||||
_ = export(fn, inp)
|
||||
|
||||
def test_constrain_value_with_no_default(self):
|
||||
def fn(x, y):
|
||||
n = x.max().item()
|
||||
constrain_as_value(n)
|
||||
return y + n
|
||||
|
||||
ep = export(fn, (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3))))
|
||||
test_inp = (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3)))
|
||||
self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
|
||||
|
||||
def test_constrain_value_with_symfloat(self):
|
||||
def fn(x, y):
|
||||
n = x.max().item()
|
||||
constrain_as_value(n)
|
||||
return y + n
|
||||
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.TorchRuntimeError, "Constraining SymFloat or Symbool is nyi"):
|
||||
_ = export(fn, (torch.rand(2, 2), torch.rand(2, 3)))
|
||||
|
||||
def test_constrain_size_in_eager(self):
|
||||
def fn(x, y):
|
||||
n = x.max().item()
|
||||
constrain_as_size(n)
|
||||
return y + n
|
||||
|
||||
ep = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
|
||||
test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
|
||||
self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
|
||||
|
||||
def test_constrain_size_with_constrain_value(self):
|
||||
def fn(x, y):
|
||||
n = x.max().item()
|
||||
constrain_as_value(n, 2, 10)
|
||||
constrain_as_size(n)
|
||||
return y + n
|
||||
|
||||
# Since we are using constrain_as_value, we expect to raise error when user
|
||||
# passes in invalid tracing input
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
|
||||
_ = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
|
||||
_ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
|
||||
|
||||
ep = export(fn, (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3))))
|
||||
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint"):
|
||||
test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
|
||||
_ = ep(*test_inp)
|
||||
|
||||
def test_constrain_size_with_various_cases(self):
|
||||
|
||||
def case_1(x, y):
|
||||
n = x.item()
|
||||
constrain_as_size(n, min=0)
|
||||
return y.sum() + torch.ones(n, 5).sum()
|
||||
|
||||
def case_2(x, y):
|
||||
n = x.item()
|
||||
constrain_as_size(n, min=0, max=6)
|
||||
return y.sum() + torch.ones(n, 5).sum()
|
||||
|
||||
def case_3(x, y):
|
||||
n = x.item()
|
||||
constrain_as_size(n, min=0, max=1)
|
||||
return y.sum() + torch.ones(n, 5).sum()
|
||||
|
||||
def case_4(x, y):
|
||||
n = x.item()
|
||||
constrain_as_size(n, min=2)
|
||||
return y.sum() + torch.ones(n, 5).sum()
|
||||
|
||||
def case_5(x, y):
|
||||
n = x.item()
|
||||
constrain_as_size(n, min=1)
|
||||
return y.sum() + torch.ones(n, 5).sum()
|
||||
|
||||
ep = export(case_1, (torch.tensor(1), torch.ones(4, 5)))
|
||||
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, inf\]."):
|
||||
_ = ep(torch.tensor(-1), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for -1 between"):
|
||||
_ = case_1(torch.tensor(-1), torch.randn(4, 5))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
ep(torch.tensor(1), torch.ones(4, 5)),
|
||||
case_1(torch.tensor(1), torch.ones(4, 5)),
|
||||
)
|
||||
)
|
||||
|
||||
ep = export(case_2, (torch.tensor(5), torch.randn(4, 5)))
|
||||
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[0, 6\]."):
|
||||
_ = ep(torch.tensor(7), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 7 between"):
|
||||
_ = case_2(torch.tensor(7), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between \[0, 6\]."):
|
||||
_ = export(case_2, (torch.tensor(9), torch.randn(4, 5)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between"):
|
||||
_ = case_2(torch.tensor(9), torch.randn(4, 5))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
ep(torch.tensor(5), torch.ones(4, 5)),
|
||||
case_2(torch.tensor(5), torch.ones(4, 5)),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"Maximum value to constrain_as_size must be greater than 2, but was 1"
|
||||
):
|
||||
_ = export(case_3, (torch.tensor(1), torch.randn(4, 5)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Max value to constrain_range_for_size must be greater than 2. got: 1"):
|
||||
_ = case_3(torch.tensor(1), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
|
||||
_ = export(case_4, (torch.tensor(1), torch.randn(4, 5)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
|
||||
_ = case_4(torch.tensor(1), torch.randn(4, 5))
|
||||
|
||||
ep = export(case_4, (torch.tensor(5), torch.randn(4, 5)))
|
||||
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[2, inf\]."):
|
||||
_ = ep(torch.tensor(1), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1"):
|
||||
_ = case_4(torch.tensor(1), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0 between \[1, 9223372036854775807\]."):
|
||||
_ = export(case_5, (torch.tensor(0), torch.randn(4, 5)))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
ep(torch.tensor(5), torch.ones(4, 5)),
|
||||
case_4(torch.tensor(5), torch.ones(4, 5)),
|
||||
)
|
||||
)
|
||||
|
||||
ep = export(case_5, (torch.tensor(5), torch.randn(4, 5)))
|
||||
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[1, inf\]."):
|
||||
_ = ep(torch.tensor(0), torch.randn(4, 5))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0"):
|
||||
_ = case_5(torch.tensor(0), torch.randn(4, 5))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
ep(torch.tensor(5), torch.ones(4, 5)),
|
||||
case_5(torch.tensor(5), torch.ones(4, 5)),
|
||||
)
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
|
|||
from torch.testing import FileCheck
|
||||
from torch._dynamo.eval_frame import is_dynamo_supported
|
||||
from torch._export import export, dynamic_dim
|
||||
from torch._export.constraints import constrain_as_value, constrain_as_size
|
||||
from torch._export.constraints import constrain_as_value
|
||||
from torch._export.passes import (
|
||||
ReplaceViewOpsWithViewCopyOpsPass,
|
||||
)
|
||||
|
|
@ -346,7 +346,7 @@ class TestPasses(TestCase):
|
|||
def test_functionalize_inline_contraints(self) -> None:
|
||||
def f(x):
|
||||
a = x.item()
|
||||
constrain_as_size(a, 4, 7)
|
||||
constrain_as_value(a, 4, 7)
|
||||
return torch.empty((a, 4))
|
||||
|
||||
ep = torch._export.export(f, (torch.tensor([7]),))
|
||||
|
|
|
|||
|
|
@ -314,6 +314,7 @@ ALLOW_LIST = [
|
|||
("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)),
|
||||
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
|
||||
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
|
||||
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
|
||||
]
|
||||
|
||||
ALLOW_LIST_COMPILED = [
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
|
|||
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
|
||||
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
|
||||
from torch._decomp import decomposition_table
|
||||
from torch._export.constraints import constrain_as_size, constrain_as_value
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
|
||||
constrain_range, guard_int, GuardOnDataDependentSymNode
|
||||
guard_int, GuardOnDataDependentSymNode
|
||||
)
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db
|
||||
|
|
@ -1041,7 +1042,7 @@ def forward(self, a_1):
|
|||
def test_item_to_constructor(self):
|
||||
def f(a):
|
||||
r = a.item()
|
||||
constrain_range(r, min=2)
|
||||
constrain_as_size(r)
|
||||
return torch.empty(r)
|
||||
|
||||
r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
|
||||
|
|
@ -1049,6 +1050,7 @@ def forward(self, a_1):
|
|||
r, """\
|
||||
def forward(self, a_1):
|
||||
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None
|
||||
sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = None, max = None)
|
||||
empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
|
||||
return empty""" # noqa: B950
|
||||
)
|
||||
|
|
@ -1127,7 +1129,7 @@ def forward(self, crop_camera_1, mask_1):
|
|||
for s in p.shape:
|
||||
guard_int(s)
|
||||
x = x[mask]
|
||||
constrain_range(x.shape[0], min=1)
|
||||
constrain_as_value(x.shape[0], min=1)
|
||||
for p in params.values():
|
||||
p.grad = None
|
||||
return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
|
||||
|
|
|
|||
|
|
@ -145,6 +145,11 @@ FILENAME_ALLOWLIST |= {
|
|||
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
|
||||
}
|
||||
|
||||
FILENAME_ALLOWLIST |= {
|
||||
_module_dir(torch) + "_export/constraints.py",
|
||||
}
|
||||
|
||||
# TODO (zhxchen17) Make exportdb importable here.
|
||||
FILENAME_ALLOWLIST |= set(
|
||||
glob.glob(_module_dir(torch) + "_export/db/examples/*.py"),
|
||||
) | {
|
||||
|
|
|
|||
|
|
@ -363,7 +363,7 @@ def export(
|
|||
# so we serialize them here instead of inside dynamo
|
||||
gm.meta["inline_constraints"] = {
|
||||
k: v
|
||||
for k, v in fake_mode.shape_env.var_to_range.items()
|
||||
for k, v in fake_mode.shape_env.runtime_var_to_range.items()
|
||||
if re.match(r"^[if]\d+$", str(k))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,48 +1,60 @@
|
|||
from typing import Optional, Callable, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import SymInt, SymFloat
|
||||
from torch._dynamo import allow_in_graph
|
||||
from torch.fx.experimental.symbolic_shapes import constrain_range_int
|
||||
from torch.utils._sympy.value_ranges import ValueRangeError
|
||||
|
||||
# `Scalar` type used in native_functions.ymal will be translated to `Union[Number, _complex]`
|
||||
# could cause type error during since `SymInt` or `SymFloat` will be used.
|
||||
# Here manually specify the type explicitly.
|
||||
sym_constrain_range: Callable[
|
||||
[Union[int, float, SymInt, SymFloat], Optional[int], Optional[int]],
|
||||
None,
|
||||
] = torch.sym_constrain_range # type: ignore[assignment]
|
||||
|
||||
|
||||
# TODO: we want to hide this min/max stuff under some abstraction similar to
|
||||
# DynamicDim
|
||||
@allow_in_graph
|
||||
def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
|
||||
"""
|
||||
Add min/max constraint on the intermediate symbol at tracing time
|
||||
Add min/max constraint on the intermediate symbol at tracing time. If called in eager mode,
|
||||
it will still check if the input value is within the specified range.
|
||||
"""
|
||||
|
||||
if not isinstance(symbol, SymInt):
|
||||
constrain_range_int(symbol, min=min, max=max)
|
||||
else:
|
||||
sym_constrain_range(symbol, min, max)
|
||||
|
||||
return symbol
|
||||
torch.sym_constrain_range(symbol, min=min, max=max)
|
||||
|
||||
|
||||
# TODO: we want to hide this min/max stuff under some abstraction similar to
|
||||
# DynamicDim
|
||||
@allow_in_graph
|
||||
def constrain_as_size(symbol, min: int = 2, max: Optional[int] = None):
|
||||
"""
|
||||
Add min/max constraint on the intermediate symbol which will be used as a size
|
||||
def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None):
|
||||
"""
|
||||
This indicates that a given int is size-like, and can be used in any context where a size is expected.
|
||||
You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
|
||||
which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
|
||||
GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
|
||||
|
||||
# TODO: we should investigate turning off 0/1 specialization for unbacked
|
||||
# SymInts
|
||||
if min < 2:
|
||||
raise ValueRangeError(
|
||||
"Unable to set min size to be <= 2 because we specialize on 0/1 sizes."
|
||||
)
|
||||
return constrain_as_value(symbol, min, max)
|
||||
This function has unusual semantics which distinguish it from constrain_as_value.
|
||||
Specifically, at compile-time, we will unsoundly assume that the resulting int is always >= 2.
|
||||
As a result, max value you pass in should always be greater than 2.
|
||||
This makes it easier to use the unbacked int in size contexts, as we will often attempt to guard on a size being zero/one
|
||||
(e.g., when computing the contiguity of a tensor, or testing if broadcasting can occur),
|
||||
which will not work on unbacked SymInts. Assuming that the int is >= 2 allows us to
|
||||
report False to these tests. Although this is technically unsound,
|
||||
in practice we observe that if your program works for all sizes >= 2,
|
||||
it probably works for zero and one too. The reason specifically assume size is >= 2 is because
|
||||
lot of PyTorch code is specialized for 0 and 1 which could result in not general graphs.
|
||||
At runtime, we only assert that the user provided min/max values are respected.
|
||||
|
||||
To demonstrate in a scenario, suppose you do
|
||||
```
|
||||
# Case 1
|
||||
# This will assume symbol is between [2, inf) at compile time, but [0, inf) at runtime
|
||||
constrain_as_size(symbol, min=0)
|
||||
|
||||
# Case 2
|
||||
# This will assume symbol is between [2, N] at compile time, but [0, N] at runtime
|
||||
constrain_as_size(symbol, min=0, max=N)
|
||||
|
||||
# Case 3
|
||||
# This is not valid case as max is <= 2
|
||||
constrain_as_size(symbol, min=0, max=1)
|
||||
|
||||
# Case 4
|
||||
# This will assume symbol is between [2, inf) at compile time, AND [2, inf) at runtime
|
||||
constrain_as_size(symbol, min=2)
|
||||
|
||||
# Case 5
|
||||
# This will assume symbol is between [2, inf) at compile time, but [1, inf) at runtime
|
||||
constrain_as_size(symbol, min=1)
|
||||
```
|
||||
"""
|
||||
torch.sym_constrain_range_for_size(symbol, min=min, max=max)
|
||||
|
|
|
|||
|
|
@ -806,7 +806,12 @@ class GraphModuleDeserializer:
|
|||
|
||||
if vr := self.symbol_name_to_range.get(val.expr_str):
|
||||
symbolic_shapes._constrain_symbol_range(
|
||||
self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type]
|
||||
self.shape_env,
|
||||
sym,
|
||||
compiler_min=vr.lower, # type: ignore[arg-type]
|
||||
compiler_max=vr.upper, # type: ignore[arg-type]
|
||||
runtime_min=vr.lower, # type: ignore[arg-type]
|
||||
runtime_max=vr.upper # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return self.shape_env.create_symintnode(sym, hint=val.hint)
|
||||
|
|
|
|||
|
|
@ -83,6 +83,11 @@ def functional_assert_async_msg_decomp(tensor, msg):
|
|||
return
|
||||
|
||||
|
||||
@register_decomposition([aten.sym_constrain_range_for_size.default])
|
||||
def sym_constrain_range_for_size(symbol, *, min=None, max=None):
|
||||
return
|
||||
|
||||
|
||||
@register_decomposition([aten.clamp])
|
||||
@pw_cast_for_opmath
|
||||
def clamp(x, min=None, max=None):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
from torch import Tensor
|
||||
from torch import SymBool, SymFloat, Tensor
|
||||
from torch._decomp import (
|
||||
_add_op_to_registry,
|
||||
_convert_out_params,
|
||||
|
|
@ -30,7 +30,10 @@ from torch._prims_common.wrappers import (
|
|||
out_wrapper,
|
||||
)
|
||||
from torch._refs import _broadcast_shapes, _maybe_broadcast
|
||||
from torch.fx.experimental.symbolic_shapes import constrain_range
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
_constrain_range_for_size,
|
||||
constrain_range,
|
||||
)
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
|
|
@ -424,6 +427,8 @@ def make_dep_token(
|
|||
|
||||
@register_meta(aten.sym_constrain_range.default)
|
||||
def sym_constrain_range(size, min=None, max=None):
|
||||
if isinstance(size, (SymFloat, SymBool)):
|
||||
raise ValueError("Constraining SymFloat or Symbool is nyi")
|
||||
constrain_range(size, min=min, max=max)
|
||||
|
||||
|
||||
|
|
@ -433,6 +438,19 @@ def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
|
|||
return dep_token
|
||||
|
||||
|
||||
@register_meta(aten.sym_constrain_range_for_size.default)
|
||||
def sym_constrain_range_for_size(size, min=None, max=None):
|
||||
if isinstance(size, (SymFloat, SymBool)):
|
||||
raise ValueError("Constraining SymFloat or Symbool is nyi")
|
||||
_constrain_range_for_size(size, min=min, max=max)
|
||||
|
||||
|
||||
@register_meta(aten._functional_sym_constrain_range_for_size.default)
|
||||
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
|
||||
aten.sym_constrain_range_for_size(size, min=min, max=max)
|
||||
return dep_token
|
||||
|
||||
|
||||
@register_meta(aten._functional_assert_async.msg)
|
||||
def functional_assert_async_meta(val, assert_msg, dep_token):
|
||||
return dep_token
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from torch import ( # noqa: F401
|
|||
SymFloat,
|
||||
SymInt,
|
||||
)
|
||||
from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
|
||||
from torch._guards import ShapeGuard, Source, TracingContext
|
||||
from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
|
||||
from torch.utils._sympy.solve import try_solve
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
|
||||
|
|
@ -304,13 +304,57 @@ def guard_scalar(a):
|
|||
else:
|
||||
raise AssertionError(f"unrecognized scalar {a}")
|
||||
|
||||
def _constrain_symbol_range(shape_env, s: sympy.Symbol, min: int, max: int):
|
||||
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int, runtime_min: int, runtime_max: int):
|
||||
if r := shape_env.var_to_range.get(s, None):
|
||||
shape_env.var_to_range[s] = ValueRanges(
|
||||
builtins.max(r.lower, min), builtins.min(r.upper, max)
|
||||
builtins.max(r.lower, compiler_min), builtins.min(r.upper, compiler_max)
|
||||
)
|
||||
else:
|
||||
shape_env.var_to_range[s] = ValueRanges(min, max)
|
||||
shape_env.var_to_range[s] = ValueRanges(compiler_min, compiler_max)
|
||||
|
||||
if r := shape_env.runtime_var_to_range.get(s, None):
|
||||
shape_env.runtime_var_to_range[s] = ValueRanges(
|
||||
builtins.max(r.lower, runtime_min), builtins.min(r.upper, runtime_max)
|
||||
)
|
||||
else:
|
||||
shape_env.runtime_var_to_range[s] = ValueRanges(runtime_min, runtime_max)
|
||||
|
||||
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
|
||||
"""
|
||||
This function is NOT INTENDED to be used by itself.
|
||||
"""
|
||||
|
||||
if isinstance(a, (SymFloat, SymBool)):
|
||||
raise ValueError("Constraining SymFloat/SymBool is nyi")
|
||||
|
||||
assert isinstance(a, SymInt)
|
||||
assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
|
||||
|
||||
if min is None:
|
||||
min = 0
|
||||
if max is None:
|
||||
max = sympy.oo
|
||||
|
||||
if max <= 2:
|
||||
raise ValueError(f"Maximum value to constrain_as_size must be greater than 2, but was {max}")
|
||||
|
||||
if max < min:
|
||||
raise ValueError(
|
||||
"Maximum value to constrain_as_size can't be less than the specified min value, "
|
||||
"received min={min} and max={max}"
|
||||
)
|
||||
|
||||
compiler_min = 2 if min < 2 else min
|
||||
|
||||
_constrain_symbol_range(
|
||||
a.node.shape_env,
|
||||
a.node.expr,
|
||||
compiler_min=compiler_min,
|
||||
compiler_max=max,
|
||||
runtime_min=min,
|
||||
runtime_max=max
|
||||
)
|
||||
|
||||
|
||||
# inclusive both ways
|
||||
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
||||
|
|
@ -350,8 +394,16 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
|||
min = -sympy.oo
|
||||
if max is None:
|
||||
max = sympy.oo
|
||||
if not isinstance(a, SymInt):
|
||||
constrain_range_int(a, min=min, max=max)
|
||||
|
||||
if max < min:
|
||||
raise ValueError(
|
||||
"Maximum value to constrain_as_size can't be less than the specified min value, "
|
||||
"received min={min} and max={max}"
|
||||
)
|
||||
|
||||
if isinstance(a, int):
|
||||
if not (min <= a <= max):
|
||||
raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
|
||||
return
|
||||
|
||||
if isinstance(a.node.expr, sympy.Integer):
|
||||
|
|
@ -364,35 +416,15 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
|||
# semantics that this is an "unchecked" assert (but it this actually
|
||||
# something useful? Might be better to restrict only for unbacked
|
||||
# SymInt).
|
||||
_constrain_symbol_range(a.node.shape_env, a.node.expr, min, max)
|
||||
_constrain_symbol_range(
|
||||
a.node.shape_env,
|
||||
a.node.expr,
|
||||
compiler_min=min,
|
||||
compiler_max=max,
|
||||
runtime_min=min,
|
||||
runtime_max=max
|
||||
)
|
||||
|
||||
def constrain_range_int(a, *, min, max):
|
||||
"""
|
||||
Constrain range on concrete int value.
|
||||
This can happens for the following scenarios:
|
||||
- Eager mode execution and real int value is provided.
|
||||
- During tracing the traced symbol is resolved as a static integer (see
|
||||
PR #101655 for more details).
|
||||
"""
|
||||
if min is None:
|
||||
min = -sympy.oo
|
||||
if max is None:
|
||||
max = sympy.oo
|
||||
|
||||
assert not isinstance(a, SymInt)
|
||||
if not (min <= a <= max):
|
||||
raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
|
||||
|
||||
if (
|
||||
(fake_mode := detect_fake_mode()) is not None and
|
||||
getattr(fake_mode, "shape_env", None) is not None
|
||||
):
|
||||
# If we are tracing with a fake mode then add this integer to the
|
||||
# shape_env's var_to_range
|
||||
sym_integer = sympy.Integer(a)
|
||||
shape_env = fake_mode.shape_env
|
||||
_constrain_symbol_range(shape_env, sym_integer, min, max)
|
||||
shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
|
||||
|
||||
def constrain_unify(a, b):
|
||||
"""
|
||||
|
|
@ -1938,6 +1970,10 @@ class ShapeEnv:
|
|||
# range may contain ints which may not actually appear in
|
||||
# practice
|
||||
self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
|
||||
# Maps symbolic ints to their min/max range for runtime checks.
|
||||
# This is because we assume a graph generated with N=2 is general enough
|
||||
# for N < 2. Therefore, it will be too strict to assert N=2 at runtime.
|
||||
self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
|
||||
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
|
||||
self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {}
|
||||
# Maps symbolic ints to the guards that refine their lower/upper
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ _side_effectful_functions: Set[Callable] = {
|
|||
_ops.aten._assert_async.msg,
|
||||
_ops.aten.copy_.default,
|
||||
_ops.aten.sym_constrain_range.default,
|
||||
_ops.aten.sym_constrain_range_for_size.default,
|
||||
_ops.profiler._record_function_enter,
|
||||
_ops.profiler._record_function_enter_new,
|
||||
_ops.profiler._record_function_exit}
|
||||
|
|
|
|||
|
|
@ -186,6 +186,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.sym_min,
|
||||
torch.sym_not,
|
||||
torch.sym_constrain_range,
|
||||
torch.sym_constrain_range_for_size,
|
||||
torch.tril_indices,
|
||||
torch.triu_indices,
|
||||
torch.vander,
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
|
|||
"record_stream", # no return
|
||||
"sparse_dim", # returns an int
|
||||
"sym_constrain_range", # no return
|
||||
"sym_constrain_range_for_size", # no return
|
||||
"_nested_tensor_storage_offsets", # returns a vector of ints
|
||||
"_chunk_grad_outputs_efficient_attention", # returns a bool
|
||||
"_fused_sdp_choice", # returns an int
|
||||
|
|
|
|||
Loading…
Reference in a new issue