diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 89fa1dc2913..8fab77c3996 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -6299,11 +6299,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
         self._validate(fn, backend, x)
 
     def test_override_fallthrough_dispatch_key(self):
-        class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
-            def __init__(self):
-                super().__init__("_fallthrough_test_only")
-
-        test_op = _FallthroughTestOnly()
+        test_op = torch._ops.HigherOrderOperator("_fallthrough_test_only")
         default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
         self.assertTrue(
             not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 97814ddbff5..01dfcf87188 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -4989,11 +4989,7 @@ def forward(self, x_1):
 
 
 def construct_sum_pyop():
-    class MySum(HigherOrderOperator):
-        def __init__(self):
-            super().__init__("mysum")
-
-    mysum = MySum()
+    mysum = HigherOrderOperator("mysum")
 
     @mysum.py_impl(torch._C._functorch.TransformType.Vmap)
     def mysum_batch_rule(interpreter, x, dim):
diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py
index 38f763f0789..088aa15a4e2 100644
--- a/torch/_dynamo/_trace_wrapped_higher_order_op.py
+++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py
@@ -48,13 +48,8 @@ def trace_wrapped(*args, **kwargs):
     return _trace_wrapped_op(*args, **kwargs)
 
 
-class TraceWrapped(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("trace_wrapped")
-
-
 # TODO(jansel): need to ensure this does not get DCEed
-_trace_wrapped_op = TraceWrapped()
+_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
 
 
 def _assert_meta(grad, size, stride, dtype):
diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py
index 847cc8b3e13..d0008e5168d 100644
--- a/torch/_export/wrappers.py
+++ b/torch/_export/wrappers.py
@@ -12,12 +12,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
 from torch.utils import _pytree as pytree
 
 
-class ExportTracepoint(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("export_tracepoint")
-
-
-_export_tracepoint = ExportTracepoint()
+_export_tracepoint = HigherOrderOperator("_export_tracepoint")
 
 
 @_export_tracepoint.py_impl(ProxyTorchDispatchMode)
diff --git a/torch/_higher_order_ops/executorch_call_delegate.py b/torch/_higher_order_ops/executorch_call_delegate.py
index 4ef87f36080..e16401d1f44 100644
--- a/torch/_higher_order_ops/executorch_call_delegate.py
+++ b/torch/_higher_order_ops/executorch_call_delegate.py
@@ -25,12 +25,7 @@ from torch.fx.experimental.proxy_tensor import (
 from torch.utils._pytree import tree_flatten
 
 
-class ExecutorchCallDelegate(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("executorch_call_delegate")
-
-
-executorch_call_delegate = ExecutorchCallDelegate()
+executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py
index b91a7a3701d..cb66648589f 100644
--- a/torch/_higher_order_ops/map.py
+++ b/torch/_higher_order_ops/map.py
@@ -36,14 +36,8 @@ class MapWrapper(HigherOrderOperator):
         return map_wrapper(xs, *args)
 
 
-class MapImpl(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("map_impl")
-
-
 map = MapWrapper("map")
-
-map_impl = MapImpl()
+map_impl = HigherOrderOperator("map_impl")
 
 dummy_aot_config = AOTConfig(
     fw_compiler=None,  # type: ignore[arg-type]
diff --git a/torch/_higher_order_ops/run_const_graph.py b/torch/_higher_order_ops/run_const_graph.py
index 774f3ebc0e2..f200c75a6ba 100644
--- a/torch/_higher_order_ops/run_const_graph.py
+++ b/torch/_higher_order_ops/run_const_graph.py
@@ -8,12 +8,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
 from torch.utils import _pytree as pytree
 
 
-class RunConstGraph(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("run_const_graph")
-
-
-run_const_graph = RunConstGraph()
+run_const_graph = HigherOrderOperator("run_const_graph")
 
 
 @run_const_graph.py_impl(ProxyTorchDispatchMode)
diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py
index c58eea57506..3f961bf3ec7 100644
--- a/torch/_higher_order_ops/strict_mode.py
+++ b/torch/_higher_order_ops/strict_mode.py
@@ -28,12 +28,7 @@ def strict_mode(callable, operands):
             )
 
 
-class StrictMode(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("strict_mode")
-
-
-strict_mode_op = StrictMode()
+strict_mode_op = HigherOrderOperator("strict_mode")
 
 
 @strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)
diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py
index ef2f228bc4b..03fc5791dd5 100644
--- a/torch/_higher_order_ops/torchbind.py
+++ b/torch/_higher_order_ops/torchbind.py
@@ -16,18 +16,12 @@ from torch.utils import _pytree as pytree
 
 log = logging.getLogger(__name__)
 
-
 # The call_torchbind operator represents a method invocation on a torchbind
 # object. The calling convention is:
 # call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
 # We do not expect users to write this operator directly. Instead it will be
 # emitted by Dynamo when tracing encounters a torchbind object.
-class CallTorchBind(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("call_torchbind")
-
-
-call_torchbind = CallTorchBind()
+call_torchbind = HigherOrderOperator("call_torchbind")
 
 # Register this operator as side-effectful with FX.
 # TODO: this is not really sufficient. While passes (hopefully) check
diff --git a/torch/_ops.py b/torch/_ops.py
index 08d0ffab5e7..a2f35c9ef41 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -246,10 +246,6 @@ class HigherOrderOperator(OperatorBase):
     # practice due to name collisions.
     def __init__(self, name):
         super().__init__()
-        if type(self) is HigherOrderOperator:
-            raise RuntimeError(
-                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
-            )
         self._name = name
 
         # Make _OPNamespace not scream, this whole name based association needs a good hard look
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index f94a0d06ad6..cf181c64117 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -148,11 +148,7 @@ def get_device(args, kwargs):
 
 
 def register_run_and_save_rng_state_op():
-    class RunAndSaveRngState(HigherOrderOperator):
-        def __init__(self):
-            super().__init__("run_and_save_rng_state")
-
-    run_and_save_rng_state = RunAndSaveRngState()
+    run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
 
     run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
         autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
@@ -194,11 +190,7 @@ def register_run_and_save_rng_state_op():
 
 
 def register_run_with_rng_state_op():
-    class RunWithRngState(HigherOrderOperator):
-        def __init__(self):
-            super().__init__("run_with_rng_state")
-
-    run_with_rng_state = RunWithRngState()
+    run_with_rng_state = HigherOrderOperator("run_with_rng_state")
 
     run_with_rng_state.py_impl(DispatchKey.Autograd)(
         autograd_not_implemented(run_with_rng_state, deferred_error=True)
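
Note (not part of the diff): with the direct-instantiation guard removed from torch/_ops.py, call sites can construct a HigherOrderOperator directly and register implementations on it, as the hunks above now do. A minimal sketch of the resulting pattern follows; the operator name "my_hop" and the dense rule are illustrative placeholders, not code from this diff:

    from torch._C import DispatchKey
    from torch._ops import HigherOrderOperator

    # Direct instantiation; before this change it raised
    # "Direct instantiation of HigherOrderOperator is not allowed."
    my_hop = HigherOrderOperator("my_hop")  # hypothetical operator name

    @my_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
    def my_hop_dense(fn, x):
        # Illustrative eager rule: just invoke the wrapped callable.
        return fn(x)

This mirrors the registration style already used in strict_mode.py and rng_prims.py above, only without the one-off subclass.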