Persist torch.assert in aten graph (#100101)

This PR introduces a new operator called aten._assert_async.msg, which allows passing a tensor value and an assertion message as inputs. As part of TorchDynamo, we're replacing the use of torch._assert with this new operator so that make_fx also knows how to handle assertions. This is a subset of https://github.com/pytorch/pytorch/pull/98878; refer there for historic reviews.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100101
Approved by: https://github.com/jansel
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2023-04-27 06:53:51 -07:00 committed by PyTorch MergeBot
parent cef15ecc2e
commit d4bf76c2a4
12 changed files with 100 additions and 7 deletions

View file

@ -405,6 +405,10 @@ void _assert_async_cpu(const Tensor& self) {
TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
}
void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
  // Raise with the caller-supplied message — or a generic fallback when the
  // message is empty — unless `self` holds a single nonzero value.
  const auto msg =
      assert_msg.empty() ? c10::string_view("Assertion is failed") : assert_msg;
  TORCH_CHECK(native::is_nonzero(self), msg);
}
// Sorting-based algorithm for isin(); used when the number of test elements is large.
static void isin_sorting(
const Tensor& elements,

View file

@ -170,6 +170,9 @@
CPU: _assert_async_cpu
CUDA: _assert_async_cuda
- func: _assert_async.msg(Tensor self, str assert_msg) -> ()
dispatch:
CPU: _assert_async_msg_cpu
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()

View file

@ -2583,6 +2583,31 @@ def forward(self, x):
):
gm, _ = torch._dynamo.export(f, torch.randn(5, 6), aten_graph=True)
@config.patch(assume_static_by_default=False)
def test_export_persist_assert(self):
    """An exported aten graph keeps the assertion as aten._assert_async.msg,
    and the op survives dead-code elimination."""

    def f(x):
        assert x.shape[0] > 4, "Shape must be more than 4"
        return x.cos() + x.sin()

    gm, guard = torch._dynamo.export(
        f, torch.randn(5, 4, 6), aten_graph=True, tracing_mode="symbolic"
    )

    def has_aten_op(gm, op):
        # Scan the graph for a node whose target is the given aten op.
        return any(node.target == op for node in gm.graph.nodes)

    assert_op = torch.ops.aten._assert_async.msg
    self.assertTrue(has_aten_op(gm, assert_op))

    # The assert op is registered as side-effectful, so DCE must not drop it.
    gm.graph.eliminate_dead_code()
    gm.recompile()
    self.assertTrue(has_aten_op(gm, assert_op))

    # Violating the shape predicate at runtime surfaces the original message.
    with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
        gm(torch.randn(3, 4, 5))
def test_access_class_method_from_user_class(self):
class A:
@classmethod

View file

@ -2496,7 +2496,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))
with self.assertRaisesRegex(AssertionError, ""):
with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
def test_not_rewrite_assert_for_other_errors(self):
@ -2521,7 +2521,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))
with self.assertRaisesRegex(AssertionError, ""):
with self.assertRaisesRegex(RuntimeError, "assertion error"):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
def test_rewrite_assert_with_non_string_msg(self):

View file

@ -35,6 +35,7 @@ aten::_amp_update_scale
aten::_amp_update_scale.out
aten::_amp_update_scale_
aten::_assert_async
aten::_assert_async.msg
aten::_cdist_backward
aten::_cdist_backward.out
aten::_cdist_forward

View file

@ -1901,6 +1901,17 @@ class CommonTemplate:
with self.assertRaisesRegex(RuntimeError, ""):
fn(torch.randn(1, 5))
def test_inductor_assert(self):
    """Inductor compiles a graph containing an assertion on dynamic dims
    and still produces the correct result."""

    @torch._dynamo.optimize("inductor", dynamic=True)
    def fn(a):
        assert a.shape[0] >= 2 and a.shape[1] >= 4
        return a.cos()

    example = torch.randn(2, 4, 6)
    # Mark both leading dims dynamic so the assert is traced symbolically.
    for dim in (0, 1):
        torch._dynamo.mark_dynamic(example, dim)
    self.assertEqual(fn(example), example.cos())
def test_split(self):
def fn(a):
t = torch.split(a, 3, -1)

View file

@ -55,7 +55,13 @@ from .source import (
GlobalWeakRefSource,
LocalSource,
)
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
from .utils import (
counters,
get_fake_value,
graph_break_dup_warning_checker,
istype,
proxy_args_kwargs,
)
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
@ -249,12 +255,35 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
self.jump(inst)
return
# Manually insert torch._assert instead of python assert and jump over
# TODO maybe should respect DtoH sync intention of users later??
# Manually insert torch._assert_async instead of python assert and jump over
# assert related instructions as we don't need them anymore.
# if we see Tensor as assert statement, no need to call scalar_tensor
if isinstance(value, TensorVariable):
self.output.create_proxy(
"call_function",
torch._assert_async,
*proxy_args_kwargs((value, error_msg), {}),
)
self.jump(inst)
return
scalar_to_tensor_proxy = self.output.create_proxy(
"call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
)
scalar_to_tensor = wrap_fx_proxy(
self,
scalar_to_tensor_proxy,
example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
**VariableTracker.propagate([value]),
)
self.output.create_proxy(
"call_function",
torch._assert,
*proxy_args_kwargs((value, error_msg), {}),
torch._assert_async,
*proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
)
self.jump(inst)
return

View file

@ -54,6 +54,13 @@ def _unsafe_view(self, size):
return self.view(size)
# TODO: for now, inductor doesn't handle asserts
# because the condition is symbool -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
    """Decompose aten._assert_async.msg into a no-op.

    Per the note above, Inductor does not yet handle asserts (the condition
    shows up as a symbool -> tensor in the graph), so the check is simply
    dropped during decomposition.
    """
    return
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):

View file

@ -295,6 +295,16 @@ def meta_angle_out(self, out):
return out.copy_(torch.angle(self))
@register_meta(aten._assert_async.default)
def assert_async(val):
    """Meta kernel for aten._assert_async: the op returns nothing, so there
    is no shape/dtype to propagate — a no-op under meta/fake tensors."""
    return
@register_meta(aten._assert_async.msg)
def assert_async_meta(val, assert_msg):
    """Meta kernel for aten._assert_async.msg: like the default overload,
    it produces no tensor output, so tracing with fake tensors skips the
    runtime check entirely."""
    return
# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
assert (

View file

@ -32,6 +32,8 @@ Argument = Optional[Union[
_side_effectful_functions: Set[Callable] = {
torch._assert,
torch._assert_async,
_ops.aten._assert_async.msg,
_ops.aten.copy_.default,
_ops.profiler._record_function_enter,
_ops.profiler._record_function_enter_new,

View file

@ -387,7 +387,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.argmin: lambda input: -1,
torch.argsort: lambda input, dim=None: -1,
torch.asin: lambda input, out=None: -1,
torch._assert_async: lambda input: -1,
torch._assert_async: lambda input, msg: -1,
torch.arcsin: lambda input, out=None: -1,
torch.asinh: lambda input, out=None: -1,
torch.arcsinh: lambda input, out=None: -1,

View file

@ -49,6 +49,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
# All of these operators don't have any tensor like returns
FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"_assert_async", # no return
"_assert_async.msg", # no return
"_dimI", # returns an int
"_dimV", # returns an int
"_has_same_storage_numel", # returns a boolean