mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[HOP] Don't send HOPs to torch_dispatch (#131370)
I regretted the decision in https://github.com/pytorch/pytorch/pull/130606. Most user torch_dispatchs don't have enough to actually handle the HOP correctly, so for now I'd prefer that users explicitly define the interaction between the HOP and their torch_dispatch class. An example is FlopCounterMode: if we allow HOPs to get passed to it, it will ignore auto_functionalized(mm) by default but it will record flops for mm, which is weird. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/131370 Approved by: https://github.com/ydwu4
This commit is contained in:
parent
027f35d9e6
commit
4ac77fc6bd
2 changed files with 12 additions and 13 deletions
|
|
@ -546,9 +546,8 @@ class GraphModule(torch.nn.Module):
|
|||
a = torch.tensor([1.0, 0.0, 1.0])
|
||||
b = torch.randn(3)
|
||||
t = TwoTensor(a, b)
|
||||
res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
|
||||
self.assertEqual(res.a, torch.sin(a))
|
||||
self.assertEqual(res.b, torch.sin(b))
|
||||
with self.assertRaisesRegex(NotImplementedError, "no rule registered"):
|
||||
res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
|
||||
|
||||
called = 0
|
||||
|
||||
|
|
@ -580,10 +579,9 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
a = torch.tensor([1.0, 0.1, 1.0])
|
||||
pred = a.sum() > 0
|
||||
with MyMode():
|
||||
res = cond_op(pred, torch.sin, torch.cos, (a,))
|
||||
self.assertEqual(res, a.sin())
|
||||
self.assertEqual(torch_dispatch_called, 1)
|
||||
with self.assertRaisesRegex(NotImplementedError, "no rule registered"):
|
||||
with MyMode():
|
||||
res = cond_op(pred, torch.sin, torch.cos, (a,))
|
||||
|
||||
py_impl_called = 0
|
||||
|
||||
|
|
|
|||
|
|
@ -336,10 +336,10 @@ class HigherOrderOperator(OperatorBase):
|
|||
# TODO(rzou): we should support torch_dispatch calling convention too.
|
||||
result = handler(mode, *args, **kwargs)
|
||||
else:
|
||||
with _pop_mode_temporarily() as mode:
|
||||
result = curr_mode.__torch_dispatch__(
|
||||
self_, overloaded_types, args, kwargs
|
||||
)
|
||||
raise NotImplementedError(
|
||||
"There was no rule registered for HOP {self._name} and mode {curr_mode}. "
|
||||
"We recommend filing an issue."
|
||||
)
|
||||
if result is not NotImplemented:
|
||||
return result
|
||||
|
||||
|
|
@ -357,8 +357,9 @@ class HigherOrderOperator(OperatorBase):
|
|||
# TODO(rzou): we should support torch_dispatch calling convention too.
|
||||
result = handler(*args, **kwargs)
|
||||
else:
|
||||
result = subclass_type.__torch_dispatch__(
|
||||
self_, overloaded_types, args, kwargs
|
||||
raise NotImplementedError(
|
||||
"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
|
||||
"We recommend filing an issue."
|
||||
)
|
||||
if result is not NotImplemented:
|
||||
return result
|
||||
|
|
|
|||
Loading…
Reference in a new issue