[HOP] Don't send HOPs to torch_dispatch (#131370)

I regretted the decision in
https://github.com/pytorch/pytorch/pull/130606. Most user
__torch_dispatch__ implementations don't have enough information to
handle a HOP correctly, so for now I'd prefer that users explicitly
define the interaction between each HOP and their torch_dispatch class.
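
For reference, an explicit opt-in might look roughly like the sketch below. The `cond_op` handle and `MyMode` are illustrative, and the exact `py_impl` handler signature is an assumption based on the dispatch code touched in this PR:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

cond_op = torch.ops.higher_order.cond  # illustrative handle to the cond HOP

class MyMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Plain ATen ops still flow through here; HOPs no longer do.
        return func(*args, **(kwargs or {}))

# After this change, a mode sees a HOP only if a rule is registered for it:
@cond_op.py_impl(MyMode)
def cond_under_my_mode(mode, pred, true_fn, false_fn, operands):
    # Sketch: just run cond normally; the mode is temporarily popped
    # while this handler runs (per the dispatch code below).
    return cond_op(pred, true_fn, false_fn, operands)
```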

An example is FlopCounterMode: if we allow HOPs to be passed to it, it
will silently ignore auto_functionalized(mm) by default while still
recording flops for a plain mm, which is inconsistent.
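
For concreteness, here is how FlopCounterMode counts a plain mm today (standard torch.utils.flop_counter usage; the shapes are arbitrary):

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

a = torch.randn(64, 32)
b = torch.randn(32, 16)

with FlopCounterMode(display=False) as counter:
    torch.mm(a, b)

# An (M, K) @ (K, N) matmul costs 2 * M * K * N flops.
assert counter.get_total_flops() == 2 * 64 * 32 * 16
```

If auto_functionalized(mm) were instead routed through the same __torch_dispatch__, the wrapped mm would go uncounted, so the reported totals would depend on whether functionalization had wrapped the op.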

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131370
Approved by: https://github.com/ydwu4
rzou 2024-07-22 13:57:48 -07:00 committed by PyTorch MergeBot
parent 027f35d9e6
commit 4ac77fc6bd
2 changed files with 12 additions and 13 deletions


@@ -546,9 +546,8 @@ class GraphModule(torch.nn.Module):
         a = torch.tensor([1.0, 0.0, 1.0])
         b = torch.randn(3)
         t = TwoTensor(a, b)
-        res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
-        self.assertEqual(res.a, torch.sin(a))
-        self.assertEqual(res.b, torch.sin(b))
+        with self.assertRaisesRegex(NotImplementedError, "no rule registered"):
+            res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
         called = 0
@@ -580,10 +579,9 @@ class GraphModule(torch.nn.Module):
         a = torch.tensor([1.0, 0.1, 1.0])
         pred = a.sum() > 0
-        with MyMode():
-            res = cond_op(pred, torch.sin, torch.cos, (a,))
-        self.assertEqual(res, a.sin())
-        self.assertEqual(torch_dispatch_called, 1)
+        with self.assertRaisesRegex(NotImplementedError, "no rule registered"):
+            with MyMode():
+                res = cond_op(pred, torch.sin, torch.cos, (a,))
         py_impl_called = 0


@@ -336,10 +336,10 @@ class HigherOrderOperator(OperatorBase):
                 # TODO(rzou): we should support torch_dispatch calling convention too.
                 result = handler(mode, *args, **kwargs)
             else:
-                with _pop_mode_temporarily() as mode:
-                    result = curr_mode.__torch_dispatch__(
-                        self_, overloaded_types, args, kwargs
-                    )
+                raise NotImplementedError(
+                    f"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
+                    f"We recommend filing an issue."
+                )
             if result is not NotImplemented:
                 return result
@@ -357,8 +357,9 @@ class HigherOrderOperator(OperatorBase):
                 # TODO(rzou): we should support torch_dispatch calling convention too.
                 result = handler(*args, **kwargs)
             else:
-                result = subclass_type.__torch_dispatch__(
-                    self_, overloaded_types, args, kwargs
-                )
+                raise NotImplementedError(
+                    f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
+                    f"We recommend filing an issue."
+                )
             if result is not NotImplemented:
                 return result