mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Persist torch.assert in aten graph (#100101)
This PR introduces a new operator, aten._assert_async.msg, which takes a tensor value and an assertion message as inputs. As part of TorchDynamo, we replace uses of torch._assert with this new operator so that make_fx also knows how to handle assertions. This is a subset of https://github.com/pytorch/pytorch/pull/98878; refer there for the review history.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100101
Approved by: https://github.com/jansel
parent: cef15ecc2e
commit: d4bf76c2a4
12 changed files with 100 additions and 7 deletions
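Before the per-file changes, a quick sketch of how the new overload behaves from Python (assuming a build that includes this PR): the input tensor must hold a single nonzero value for the assertion to pass; on CPU the check is synchronous and raises a RuntimeError carrying the supplied message.

    import torch

    cond_ok = torch.tensor(True)
    torch.ops.aten._assert_async.msg(cond_ok, "x must be positive")  # passes silently

    cond_bad = torch.tensor(False)
    try:
        torch.ops.aten._assert_async.msg(cond_bad, "x must be positive")
    except RuntimeError as e:
        print(e)  # "x must be positive"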
@@ -405,6 +405,10 @@ void _assert_async_cpu(const Tensor& self) {
   TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
 }
 
+void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
+  TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
+}
+
 // Sorting-based algorithm for isin(); used when the number of test elements is large.
 static void isin_sorting(
     const Tensor& elements,
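A small sketch of the fallback branch in the CPU kernel above (assuming a build with this change): when the caller passes an empty message, TORCH_CHECK substitutes the generic "Assertion is failed" text.

    import torch

    try:
        torch.ops.aten._assert_async.msg(torch.tensor(0), "")
    except RuntimeError as e:
        print(e)  # message falls back to "Assertion is failed"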
@@ -170,6 +170,9 @@
     CPU: _assert_async_cpu
     CUDA: _assert_async_cuda
 
+- func: _assert_async.msg(Tensor self, str assert_msg) -> ()
+  dispatch:
+    CPU: _assert_async_msg_cpu
 
 - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
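The yaml entry above registers a new overload on the existing _assert_async op. A quick way to inspect the resulting schema from Python (a sketch, assuming a build that includes this change):

    import torch

    # The overload's schema should mirror the native_functions.yaml declaration.
    print(torch.ops.aten._assert_async.msg._schema)
    # _assert_async.msg(Tensor self, str assert_msg) -> ()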
@@ -2583,6 +2583,31 @@ def forward(self, x):
         ):
             gm, _ = torch._dynamo.export(f, torch.randn(5, 6), aten_graph=True)
 
+    @config.patch(assume_static_by_default=False)
+    def test_export_persist_assert(self):
+        def f(x):
+            assert x.shape[0] > 4, "Shape must be more than 4"
+            return x.cos() + x.sin()
+
+        gm, guard = torch._dynamo.export(
+            f, torch.randn(5, 4, 6), aten_graph=True, tracing_mode="symbolic"
+        )
+
+        def has_aten_op(gm, op):
+            for node in gm.graph.nodes:
+                if node.target == op:
+                    return True
+            return False
+
+        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
+
+        gm.graph.eliminate_dead_code()
+        gm.recompile()
+        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
+
+        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
+            gm(torch.randn(3, 4, 5))
+
     def test_access_class_method_from_user_class(self):
         class A:
             @classmethod
@@ -2496,7 +2496,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))
 
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"):
             exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
 
     def test_not_rewrite_assert_for_other_errors(self):
@@ -2521,7 +2521,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))
 
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(RuntimeError, "assertion error"):
             exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
 
     def test_rewrite_assert_with_non_string_msg(self):
@@ -35,6 +35,7 @@ aten::_amp_update_scale
 aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
+aten::_assert_async.msg
 aten::_cdist_backward
 aten::_cdist_backward.out
 aten::_cdist_forward
@@ -1901,6 +1901,17 @@ class CommonTemplate:
         with self.assertRaisesRegex(RuntimeError, ""):
             fn(torch.randn(1, 5))
 
+    def test_inductor_assert(self):
+        @torch._dynamo.optimize("inductor", dynamic=True)
+        def fn(a):
+            assert a.shape[0] >= 2 and a.shape[1] >= 4
+            return a.cos()
+
+        inp = torch.randn(2, 4, 6)
+        torch._dynamo.mark_dynamic(inp, 0)
+        torch._dynamo.mark_dynamic(inp, 1)
+        self.assertEqual(fn(inp), inp.cos())
+
     def test_split(self):
         def fn(a):
             t = torch.split(a, 3, -1)
@@ -55,7 +55,13 @@ from .source import (
     GlobalWeakRefSource,
     LocalSource,
 )
-from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
+from .utils import (
+    counters,
+    get_fake_value,
+    graph_break_dup_warning_checker,
+    istype,
+    proxy_args_kwargs,
+)
 from .variables.base import MutableLocal, typestr, VariableTracker
 from .variables.builder import VariableBuilder, wrap_fx_proxy
 from .variables.builtin import BuiltinVariable
@@ -249,12 +255,35 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
                 self.jump(inst)
                 return
 
-            # Manually insert torch._assert instead of python assert and jump over
+            # TODO maybe should respect DtoH sync intention of users later??
+            # Manually insert torch._assert_async instead of python assert and jump over
             # assert related instructions as we don't need them anymore.
+
+            # if we see Tensor as assert statement, no need to call scalar_tensor
+            if isinstance(value, TensorVariable):
+                self.output.create_proxy(
+                    "call_function",
+                    torch._assert_async,
+                    *proxy_args_kwargs((value, error_msg), {}),
+                )
+                self.jump(inst)
+                return
+
+            scalar_to_tensor_proxy = self.output.create_proxy(
+                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
+            )
+
+            scalar_to_tensor = wrap_fx_proxy(
+                self,
+                scalar_to_tensor_proxy,
+                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
+                **VariableTracker.propagate([value]),
+            )
+
             self.output.create_proxy(
                 "call_function",
-                torch._assert,
-                *proxy_args_kwargs((value, error_msg), {}),
+                torch._assert_async,
+                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
             )
             self.jump(inst)
             return
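The rewrite above replaces a Python assert with a graph-level assert: a tensor condition is passed straight to torch._assert_async, while a scalar or symbolic condition is first materialized with torch.scalar_tensor. An eager-mode sketch of the traced pattern (illustrative only, not Dynamo's actual code path; assumes a build with the .msg overload):

    import torch

    x = torch.randn(5, 4, 6)
    cond = x.shape[0] > 4  # plain Python bool, like the assert condition
    # Dynamo traces roughly this call sequence into the graph:
    torch._assert_async(torch.scalar_tensor(cond), "Shape must be more than 4")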
@@ -54,6 +54,13 @@ def _unsafe_view(self, size):
     return self.view(size)
 
 
+# TODO: for now, inductor doesn't handle asserts
+# because the condition is symbool -> tensor in the graph.
+@register_decomposition([aten._assert_async.msg])
+def assert_async_msg_decomp(tensor, msg):
+    return
+
+
 @register_decomposition([aten.clamp])
 @pw_cast_for_opmath
 def clamp(x, min=None, max=None):
@@ -295,6 +295,16 @@ def meta_angle_out(self, out):
     return out.copy_(torch.angle(self))
 
 
+@register_meta(aten._assert_async.default)
+def assert_async(val):
+    return
+
+
+@register_meta(aten._assert_async.msg)
+def assert_async_meta(val, assert_msg):
+    return
+
+
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def squareCheckInputs(self: Tensor, f_name: str):
     assert (
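These meta kernels are intentionally empty so the op can be traced with fake/meta tensors, where there is no real data to check. A small sketch of the intended behavior (assumes a build with the registrations above):

    import torch

    t = torch.empty((), device="meta")
    # Expected to run as a no-op on the meta device: nothing is checked, nothing is raised.
    torch.ops.aten._assert_async.msg(t, "never evaluated here")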
@@ -32,6 +32,8 @@ Argument = Optional[Union[
 
 _side_effectful_functions: Set[Callable] = {
     torch._assert,
+    torch._assert_async,
+    _ops.aten._assert_async.msg,
     _ops.aten.copy_.default,
     _ops.profiler._record_function_enter,
     _ops.profiler._record_function_enter_new,
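Registering the op as side-effectful is what lets exported graphs keep assertions whose results are never used. A minimal FX sketch of the mechanism, using torch._assert (which was already in this set; the PR extends the same treatment to the new aten overload):

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        torch._assert(x.shape[0] > 1, "need at least 2 rows")
        return x + 1

    gm = symbolic_trace(f)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    # The _assert node survives dead-code elimination because torch._assert
    # is listed in _side_effectful_functions.
    print(gm.graph)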
@@ -387,7 +387,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.argmin: lambda input: -1,
     torch.argsort: lambda input, dim=None: -1,
     torch.asin: lambda input, out=None: -1,
-    torch._assert_async: lambda input: -1,
+    torch._assert_async: lambda input, msg: -1,
     torch.arcsin: lambda input, out=None: -1,
     torch.asinh: lambda input, out=None: -1,
     torch.arcsinh: lambda input, out=None: -1,
@@ -49,6 +49,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
 # All of these operators don't have any tensor like returns
 FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
     "_assert_async",  # no return
+    "_assert_async.msg",  # no return
     "_dimI",  # returns an int
     "_dimV",  # returns an int
     "_has_same_storage_numel",  # returns a boolean