Revert "[hop] ban creating hop by directly instantiating HigherOrderOperator. (#133645)"

This reverts commit 696107efcb.

Reverted https://github.com/pytorch/pytorch/pull/133645 on behalf of https://github.com/ydwu4 due to breaking ci. probably due to land race ([comment](https://github.com/pytorch/pytorch/pull/133645#issuecomment-2302866106))
This commit is contained in:
PyTorch MergeBot 2024-08-21 19:33:14 +00:00
parent 5fcfccefc6
commit 1491a61769
11 changed files with 11 additions and 68 deletions

View file

@@ -6299,11 +6299,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
self._validate(fn, backend, x)
def test_override_fallthrough_dispatch_key(self):
class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
def __init__(self):
super().__init__("_fallthrough_test_only")
test_op = _FallthroughTestOnly()
test_op = torch._ops.HigherOrderOperator("_fallthrough_test_only")
default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
self.assertTrue(
not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)

View file

@@ -4989,11 +4989,7 @@ def forward(self, x_1):
def construct_sum_pyop():
class MySum(HigherOrderOperator):
def __init__(self):
super().__init__("mysum")
mysum = MySum()
mysum = HigherOrderOperator("mysum")
@mysum.py_impl(torch._C._functorch.TransformType.Vmap)
def mysum_batch_rule(interpreter, x, dim):

View file

@@ -48,13 +48,8 @@ def trace_wrapped(*args, **kwargs):
return _trace_wrapped_op(*args, **kwargs)
class TraceWrapped(HigherOrderOperator):
def __init__(self):
super().__init__("trace_wrapped")
# TODO(jansel): need to ensure this does not get DCEed
_trace_wrapped_op = TraceWrapped()
_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
def _assert_meta(grad, size, stride, dtype):

View file

@@ -12,12 +12,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
from torch.utils import _pytree as pytree
class ExportTracepoint(HigherOrderOperator):
def __init__(self):
super().__init__("export_tracepoint")
_export_tracepoint = ExportTracepoint()
_export_tracepoint = HigherOrderOperator("_export_tracepoint")
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)

View file

@@ -25,12 +25,7 @@ from torch.fx.experimental.proxy_tensor import (
from torch.utils._pytree import tree_flatten
class ExecutorchCallDelegate(HigherOrderOperator):
def __init__(self):
super().__init__("executorch_call_delegate")
executorch_call_delegate = ExecutorchCallDelegate()
executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)

View file

@@ -36,14 +36,8 @@ class MapWrapper(HigherOrderOperator):
return map_wrapper(xs, *args)
class MapImpl(HigherOrderOperator):
def __init__(self):
super().__init__("map_impl")
map = MapWrapper("map")
map_impl = MapImpl()
map_impl = HigherOrderOperator("map_impl")
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]

View file

@@ -8,12 +8,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
from torch.utils import _pytree as pytree
class RunConstGraph(HigherOrderOperator):
def __init__(self):
super().__init__("run_const_graph")
run_const_graph = RunConstGraph()
run_const_graph = HigherOrderOperator("run_const_graph")
@run_const_graph.py_impl(ProxyTorchDispatchMode)

View file

@@ -28,12 +28,7 @@ def strict_mode(callable, operands):
)
class StrictMode(HigherOrderOperator):
def __init__(self):
super().__init__("strict_mode")
strict_mode_op = StrictMode()
strict_mode_op = HigherOrderOperator("strict_mode")
@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)

View file

@@ -16,18 +16,12 @@ from torch.utils import _pytree as pytree
log = logging.getLogger(__name__)
# The call_torchbind operator represents a method invocation on a torchbind
# object. The calling convention is:
# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
# We do not expect users to write this operator directly. Instead it will be
# emitted by Dynamo when tracing encounters a torchbind object.
class CallTorchBind(HigherOrderOperator):
def __init__(self):
super().__init__("call_torchbind")
call_torchbind = CallTorchBind()
call_torchbind = HigherOrderOperator("call_torchbind")
# Register this operator as side-effectful with FX.
# TODO: this is not really sufficient. While passes (hopefully) check

View file

@@ -246,10 +246,6 @@ class HigherOrderOperator(OperatorBase):
# practice due to name collisions.
def __init__(self, name):
super().__init__()
if type(self) is HigherOrderOperator:
raise RuntimeError(
"Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
)
self._name = name
# Make _OPNamespace not scream, this whole name based association needs a good hard look

View file

@@ -148,11 +148,7 @@ def get_device(args, kwargs):
def register_run_and_save_rng_state_op():
class RunAndSaveRngState(HigherOrderOperator):
def __init__(self):
super().__init__("run_and_save_rng_state")
run_and_save_rng_state = RunAndSaveRngState()
run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
@@ -194,11 +190,7 @@ def register_run_and_save_rng_state_op():
def register_run_with_rng_state_op():
class RunWithRngState(HigherOrderOperator):
def __init__(self):
super().__init__("run_with_rng_state")
run_with_rng_state = RunWithRngState()
run_with_rng_state = HigherOrderOperator("run_with_rng_state")
run_with_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_with_rng_state, deferred_error=True)