From 4ac77fc6bd541b6566db33fb9e60d2cda353422c Mon Sep 17 00:00:00 2001
From: rzou
Date: Mon, 22 Jul 2024 13:57:48 -0700
Subject: [PATCH] [HOP] Don't send HOPs to torch_dispatch (#131370)

I regretted the decision in https://github.com/pytorch/pytorch/pull/130606.
Most user __torch_dispatch__ implementations don't have enough information
to handle a HOP correctly, so for now I'd prefer that users explicitly
define the interaction between the HOP and their torch_dispatch class.

An example is FlopCounterMode: if we allow HOPs to be passed to it, it will
ignore auto_functionalized(mm) by default but record flops for a plain mm,
which is inconsistent.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131370
Approved by: https://github.com/ydwu4
---
 test/dynamo/test_higher_order_ops.py | 12 +++++-------
 torch/_ops.py                        | 13 +++++++------
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 9fce5203486..8637fe26d9e 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -546,9 +546,8 @@ class GraphModule(torch.nn.Module):
         a = torch.tensor([1.0, 0.0, 1.0])
         b = torch.randn(3)
         t = TwoTensor(a, b)
-        res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
-        self.assertEqual(res.a, torch.sin(a))
-        self.assertEqual(res.b, torch.sin(b))
+        with self.assertRaisesRegex(NotImplementedError, "no rule registered"):
+            res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
 
         called = 0
 
@@ -580,10 +579,9 @@ class GraphModule(torch.nn.Module):
         a = torch.tensor([1.0, 0.1, 1.0])
         pred = a.sum() > 0
 
-        with MyMode():
-            res = cond_op(pred, torch.sin, torch.cos, (a,))
-        self.assertEqual(res, a.sin())
-        self.assertEqual(torch_dispatch_called, 1)
+        with self.assertRaisesRegex(NotImplementedError, "no rule registered"):
+            with MyMode():
+                res = cond_op(pred, torch.sin, torch.cos, (a,))
 
         py_impl_called = 0
 
diff --git a/torch/_ops.py b/torch/_ops.py
index b9357e16e31..d8668138430 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -336,10 +336,10 @@ class HigherOrderOperator(OperatorBase):
                 # TODO(rzou): we should support torch_dispatch calling convention too.
                 result = handler(mode, *args, **kwargs)
             else:
-                with _pop_mode_temporarily() as mode:
-                    result = curr_mode.__torch_dispatch__(
-                        self_, overloaded_types, args, kwargs
-                    )
+                raise NotImplementedError(
+                    f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
+                    f"We recommend filing an issue."
+                )
             if result is not NotImplemented:
                 return result
 
@@ -357,8 +357,9 @@ class HigherOrderOperator(OperatorBase):
                 # TODO(rzou): we should support torch_dispatch calling convention too.
                 result = handler(*args, **kwargs)
             else:
-                result = subclass_type.__torch_dispatch__(
-                    self_, overloaded_types, args, kwargs
+                raise NotImplementedError(
+                    f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
+                    f"We recommend filing an issue."
                 )
             if result is not NotImplemented:
                 return result
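
Note on the intended workaround: after this change, a HOP reaching an unknown
TorchDispatchMode or Tensor subclass raises instead of silently flowing into
__torch_dispatch__, so the user-side fix is to register an explicit rule via
py_impl. Below is a minimal sketch of such a rule for cond. The import path
for cond_op and the eager branch semantics in the handler are illustrative
assumptions for this note, not something this patch adds; the diff does show
that a mode handler is invoked as handler(mode, *args, **kwargs).

    import torch
    from torch._higher_order_ops.cond import cond_op  # assumed import path
    from torch.utils._python_dispatch import TorchDispatchMode

    class MyMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # A pass-through mode; real modes would do their bookkeeping here.
            kwargs = kwargs or {}
            return func(*args, **kwargs)

    # Explicit rule: with this registration, dispatch() finds a handler for
    # MyMode and never reaches the NotImplementedError introduced above.
    # The handler receives the mode as its first argument.
    @cond_op.py_impl(MyMode)
    def cond_under_my_mode(mode, pred, true_fn, false_fn, operands):
        # Illustrative eager semantics: pick a branch and run it.
        return true_fn(*operands) if pred else false_fn(*operands)

    # Usage: under MyMode, cond now runs the explicit rule rather than raising.
    a = torch.tensor([1.0, 0.1, 1.0])
    with MyMode():
        res = cond_op(a.sum() > 0, torch.sin, torch.cos, (a,))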