From d29094888ba28000c529ab4f5871d5287b68e3b8 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 2 Oct 2024 09:13:19 +0000 Subject: [PATCH] Use torch.Stream&torch.Event for Dynamo capature (#134850) # Motivation This PR aims to solve the multiple Inheritance problem. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134850 Approved by: https://github.com/yf225, https://github.com/EikanWang --- .../triton/device_interface.py | 5 +- torch/_dynamo/device_interface.py | 36 +++++------- torch/_dynamo/variables/builder.py | 12 ++-- torch/_dynamo/variables/torch.py | 3 +- torch/_streambase.py | 58 +++++-------------- torch/cuda/streams.py | 5 +- torch/xpu/streams.py | 8 +-- 7 files changed, 46 insertions(+), 81 deletions(-) diff --git a/test/inductor/extension_backends/triton/device_interface.py b/test/inductor/extension_backends/triton/device_interface.py index c7cabf31dc6..9ca96e71a7d 100644 --- a/test/inductor/extension_backends/triton/device_interface.py +++ b/test/inductor/extension_backends/triton/device_interface.py @@ -2,6 +2,7 @@ from __future__ import annotations import time +import torch from torch._dynamo import device_interface # noqa: PLC2701 import-private-name @@ -13,9 +14,7 @@ class DeviceProperties: class DeviceInterface(device_interface.DeviceInterface): - class Event( - device_interface._EventBase - ): # pyright: ignore [reportPrivateImportUsage] + class Event(torch.Event): def __init__( self, enable_timing: bool = False, diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index ad58a11bbd5..baa26c64789 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,11 +1,9 @@ # mypy: allow-untyped-defs -import inspect import time from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union import torch -from torch._streambase import _EventBase, _StreamBase get_cuda_stream: Optional[Callable[[int], int]] @@ -21,21 +19,7 @@ caching_worker_device_properties: Dict[str, Any] = {} caching_worker_current_devices: Dict[str, int] = {} -class DeviceInterfaceMeta(type): - def __new__(metacls, *args, **kwargs): - class_member = args[2] - if "Event" in class_member: - assert inspect.isclass(class_member["Event"]) and issubclass( - class_member["Event"], _EventBase - ), "DeviceInterface member Event should be inherit from _EventBase" - if "Stream" in class_member: - assert inspect.isclass(class_member["Stream"]) and issubclass( - class_member["Stream"], _StreamBase - ), "DeviceInterface member Stream should be inherit from _StreamBase" - return super().__new__(metacls, *args, **kwargs) - - -class DeviceInterface(metaclass=DeviceInterfaceMeta): +class DeviceInterface: """ This is a simple device runtime interface for Inductor. It enables custom backends to be integrated with Inductor in a device-agnostic semantic. @@ -45,6 +29,18 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta): def __new__(cls, device: _device_t): raise NotImplementedError + class Event: + def __new__(cls, *args, **kwargs): + raise NotImplementedError( + "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo." + ) + + class Stream: + def __new__(cls, *args, **kwargs): + raise NotImplementedError( + "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo." + ) + class Worker: """ Worker API to query device properties that will work in multi processing @@ -161,7 +157,7 @@ class CudaInterface(DeviceInterface): device = torch.cuda.device # register Event and Stream class into the backend interface - # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase + # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream Event = torch.cuda.Event Stream = torch.cuda.Stream @@ -303,14 +299,14 @@ class CpuDeviceProperties: class CpuInterface(DeviceInterface): - class Event(_EventBase): + class Event(torch.Event): def __init__(self, enable_timing=True): self.time = 0.0 def elapsed_time(self, end_event) -> float: return (end_event.time - self.time) * 1000 - def record(self): + def record(self, stream=None): self.time = time.perf_counter() @staticmethod diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 8f37e9a9d1d..e9be323bce0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -36,7 +36,6 @@ from torch import SymInt from torch._guards import GuardSource, TracingContext from torch._higher_order_ops.torchbind import call_torchbind from torch._ops import HigherOrderOperator -from torch._streambase import _EventBase, _StreamBase from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode from torch._subclasses.meta_utils import is_sparse_any, safe_grad from torch._utils_internal import justknobs_check @@ -822,7 +821,7 @@ class VariableBuilder: stream_source = AttrSource(self.source, "stream") stream_var = VariableBuilder(self.tx, stream_source)(value.stream) return StreamContextVariable.create(self.tx, stream_var) - elif isinstance(value, _StreamBase): + elif isinstance(value, torch.Stream): self.install_guards(GuardBuilder.ID_MATCH) stream_proxy = self.tx.output.create_proxy( "call_function", @@ -847,7 +846,7 @@ class VariableBuilder: elif isinstance(value, torch._C._SDPBackend): self.install_guards(GuardBuilder.ID_MATCH) return ConstantVariable(value) - elif isinstance(value, _EventBase): + elif isinstance(value, torch.Event): self.install_guards(GuardBuilder.ID_MATCH) torch._dynamo.utils.store_user_object_weakref(value) event_proxy = self.tx.output.create_proxy( @@ -2265,7 +2264,7 @@ def wrap_fx_proxy_cls( return SymNodeVariable(proxy, example_value, **options) elif ( inspect.isclass(proxy.node.target) - and issubclass(proxy.node.target, _StreamBase) + and issubclass(proxy.node.target, torch.Stream) ) or proxy.node.target in [ device_interface.current_stream for _, device_interface in get_registered_device_interfaces() @@ -2273,7 +2272,8 @@ def wrap_fx_proxy_cls( set_example_value(proxy.node, example_value) return StreamVariable(proxy, example_value, example_value.device, **options) elif ( - inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase) + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, torch.Event) ) or proxy.node.target in [ device_interface.Event for _, device_interface in get_registered_device_interfaces() @@ -2285,7 +2285,7 @@ def wrap_fx_proxy_cls( return ConstantVariable(example_value, **options) elif ( example_value is not None - and isinstance(example_value, _EventBase) + and isinstance(example_value, torch.Event) and proxy.node.target == "record_event" and proxy.node.op == "call_method" ): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 478ab82de43..ed377fd463c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -13,7 +13,6 @@ import torch.fx import torch.nn from torch._guards import TracingContext from torch._logging import warning_once -from torch._streambase import _StreamBase from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type from .. import config, polyfills, variables @@ -267,7 +266,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable): assert len(args) <= 1 and len(kwargs) == 0 inf_mode = args[0].as_python_constant() if len(args) == 1 else True return InferenceModeVariable.create(tx, inf_mode) - elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase): + elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream): from torch._dynamo.variables.builder import wrap_fx_proxy_cls return wrap_fx_proxy_cls( diff --git a/torch/_streambase.py b/torch/_streambase.py index 85e203a3d99..9d71120c959 100644 --- a/torch/_streambase.py +++ b/torch/_streambase.py @@ -1,46 +1,20 @@ -# mypy: allow-untyped-defs -from abc import ABC, abstractmethod +from typing_extensions import deprecated + +import torch -class _StreamBase(ABC): - r"""Base stream class abstraction for multi backends Stream to herit from""" - - @abstractmethod - def wait_event(self, event) -> None: - raise NotImplementedError - - @abstractmethod - def wait_stream(self, stream) -> None: - raise NotImplementedError - - @abstractmethod - def record_event(self, event=None) -> None: - raise NotImplementedError - - @abstractmethod - def query(self) -> bool: - raise NotImplementedError - - @abstractmethod - def synchronize(self) -> None: - raise NotImplementedError - - @abstractmethod - def __eq__(self, stream) -> bool: - raise NotImplementedError +# Preserved only for BC reasons +@deprecated( + "`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.", + category=FutureWarning, +) +class _StreamBase(torch.Stream): + pass -class _EventBase(ABC): - r"""Base Event class abstraction for multi backends Event to herit from""" - - @abstractmethod - def wait(self, stream=None) -> None: - raise NotImplementedError - - @abstractmethod - def query(self) -> bool: - raise NotImplementedError - - @abstractmethod - def synchronize(self) -> None: - raise NotImplementedError +@deprecated( + "`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.", + category=FutureWarning, +) +class _EventBase(torch.Event): + pass diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index d4ee6eb68d6..6ef0baeeaf4 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -2,7 +2,6 @@ import ctypes import torch -from torch._streambase import _EventBase, _StreamBase from torch._utils import _dummy_type @@ -12,7 +11,7 @@ if not hasattr(torch._C, "_CudaStreamBase"): torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase") -class Stream(torch._C._CudaStreamBase, _StreamBase): +class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. A CUDA stream is a linear sequence of execution that belongs to a specific @@ -138,7 +137,7 @@ class ExternalStream(Stream): return super().__new__(cls, stream_ptr=stream_ptr, **kwargs) -class Event(torch._C._CudaEventBase, _EventBase): +class Event(torch._C._CudaEventBase): r"""Wrapper around a CUDA event. CUDA events are synchronization markers that can be used to monitor the diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index 19a7cda162f..beb438be466 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -2,9 +2,7 @@ import ctypes import torch -from torch._streambase import _EventBase, _StreamBase - -from .._utils import _dummy_type +from torch._utils import _dummy_type if not hasattr(torch._C, "_XpuStreamBase"): @@ -13,7 +11,7 @@ if not hasattr(torch._C, "_XpuStreamBase"): torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase") -class Stream(torch._C._XpuStreamBase, _StreamBase): +class Stream(torch._C._XpuStreamBase): r"""Wrapper around a XPU stream. A XPU stream is a linear sequence of execution that belongs to a specific @@ -98,7 +96,7 @@ class Stream(torch._C._XpuStreamBase, _StreamBase): return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})" -class Event(torch._C._XpuEventBase, _EventBase): +class Event(torch._C._XpuEventBase): r"""Wrapper around a XPU event. XPU events are synchronization markers that can be used to monitor the