pytorch/torch/utils/_python_dispatch.py



[Reland] Add python mode (#64360)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360

This PR adds a (private) enable_python_mode context manager
(see torch/utils/_python_dispatch.py). enable_python_mode accepts the type of
a __torch_dispatch__ object as its argument. Whenever an operator gets called
inside of the context manager, it dispatches to the __torch_dispatch__ of the
passed-in type. Example usage:

```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the type
of a `__torch_dispatch__` object (e.g. LoggingTensor). It holds both the
PyObject* representing the class and a PyInterpreter* so we know which Python
interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept a
`const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this is null,
dispatching happens as usual. When it is non-null, we prepend the
TorchDispatchTypeObject's PyObject* to the overloaded args list so that it is
considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
  and `append_overloaded_arg` to handle types in addition to Tensors.

Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
  DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen. This split is
  due to the libtorch_python library boundary (because we need to save TLS in
  ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up the relevant
  TorchDispatchTypeObject (if Python Mode is active) and dispatch using it.

There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an exclude
  guard. enable_python_mode currently does not handle torch.tensor and the
  exclude guard is to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we
  should be able to enable this with a more sane no_dispatch API and by
  changing the TLS to a stack. For now I did not need this for
  CompositeImplicitAutograd testing.

Test Plan: new tests

Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
import contextlib
from typing import Iterator, Set
import functools
from torch.utils._mode_utils import _enable_mode, _push_mode, _ModeInfo, _wrap_init, _restore_mode
from torch._C import _get_torch_dispatch_mode, _set_torch_dispatch_mode
from dataclasses import dataclass
@dataclass
class TorchDispatchModeInfo(_ModeInfo):
    def __init__(self):
        super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode,
                         base_mode_class=BaseTorchDispatchMode)

    def get_mode(self):
        return _get_torch_dispatch_mode()

    def set_mode(self, mode):
        return _set_torch_dispatch_mode(mode)
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
@contextlib.contextmanager
def enable_torch_dispatch_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
    """
    Context manager that causes all pytorch operators to dispatch to the passed-in
    type's __torch_dispatch__ function, including operations that accept no tensors
    but return a tensor.

    This function is non-compositional; if there is already an existing mode,
    it will raise an error.

    This function is safe to use inside a ``__torch_dispatch__`` mode handler,
    as the mode is guaranteed to be disabled in this context.  You can use
    this context manager to reinstate the mode so that calls to overridable
    APIs recursively call back into your mode handler (this can easily cause
    infinite loops, so use with care!)

    enable_torch_dispatch_mode is affected by _DisableTorchDispatch.

    Args:
        mode (:class:`TorchDispatchMode`, Tensor-like class, or None): the
            mode to set as current mode.  If you pass a Tensor-like class,
            it will be treated as a non-compositional mode with no state,
            which is convenient if you have an existing tensor subclass
            that you'd like to apply globally in a quick and dirty way.
            Passing None will disable the current mode.
        replace (:class:`TorchDispatchMode` or Tensor-like class): the
            mode to replace.  You can use this argument to change the mode in
            a situation where you know what the current mode is (and you are
            intentionally overwriting it.)  If you don't know what the current
            mode is, use ``ignore_preexisting`` instead.
        ignore_preexisting (bool): if True, ignore any preexisting mode
            and overwrite it with the passed mode.
    """
    return _enable_mode(mode, mode_info=TorchDispatchModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting)
def _wrap_torch_dispatch(f):
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if isinstance(f, classmethod):
            raise RuntimeError("TorchDispatchMode's torch_dispatch function "
                               "should be a normal method not a class method")
        inner = getattr(self, "inner", None)
        with enable_torch_dispatch_mode(inner):
            return f(self, *args, **kwargs)
    return wrapped
# Implementation note, since this is based on TorchFunctionMode, this had the
# same dilemma: I had a choice about how much of mode stacks
# to implement in Python versus in C++. At time of writing, I did not care
# too much about implementation efficiency; however, I do care about making it
# hard for users to implement modes in the wrong way. In the end, it turned
# out to be possible to implement mode stacks entirely from userland, with the
# C++ API providing only _get_torch_dispatch_mode() and
# _set_torch_dispatch_mode(), so I opted to provide some unsafe C++ bindings and
# have the bulk of the logic for managing the stack in Python, which helped
# simplify the C++ API surface. It would also have been valid to build in the
# notion of mode stack directly into C++ but in this design it's substantially
# more difficult to interact with TorchDispatchModeMeta.
class TorchDispatchModeMeta(type):
    """
    Metaclass for :class:`TorchDispatchMode`; it does two things:

        * Adds an implicit ``inner`` kwarg to ``__init__``, to
          allow the modes to be chained together to form a stack.

        * Reenables the inner mode, so that by default PyTorch API calls
          will compositionally proceed to the next mode on the stack.

    The default behavior for the second bullet is important, as it is easy to
    accidentally write ``__torch_dispatch__`` implementations that are not
    compositional, and the wrapping here makes the obvious code do the
    right thing (aka, this is why there is a metaclass).
    """
    def __new__(metacls, name, bases, dct):
        if '__init__' in dct:
            dct['__init__'] = _wrap_init(dct['__init__'])
        if '__torch_dispatch__' in dct:
            dct['__torch_dispatch__'] = _wrap_torch_dispatch(dct['__torch_dispatch__'])
        return super().__new__(metacls, name, bases, dct)
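The metaclass pattern above can be shown in isolation: rewriting entries in the class dict inside ``__new__`` means every subclass gets the wrapping for free, with no cooperation from the subclass author. The names below (`WrapMeta`, `handler`, `_trace`) are invented for this sketch:

```python
import functools

calls = []

def _trace(f):
    # Analogous to _wrap_torch_dispatch: wrap the raw function in bookkeeping.
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        calls.append(f.__name__)
        return f(self, *args, **kwargs)
    return wrapped

class WrapMeta(type):
    def __new__(metacls, name, bases, dct):
        # Same shape as TorchDispatchModeMeta.__new__: swap the raw function
        # in the class dict for a wrapped one before the class is created.
        if 'handler' in dct:
            dct['handler'] = _trace(dct['handler'])
        return super().__new__(metacls, name, bases, dct)

class MyMode(metaclass=WrapMeta):
    def handler(self):
        return 42

assert MyMode().handler() == 42
assert calls == ['handler']
```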
class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overridable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack with :func:`push_torch_dispatch_mode`.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``enable_torch_dispatch_mode(self, replace=self.inner)`` to make PyTorch
    API self-referential (beware of infinite loops, in this case!)
    """
    # Force metaclass to generate constructor at the base of the hierarchy
    def __init__(self):
        self.ancestors: Set[TorchDispatchMode]

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        old = _get_torch_dispatch_mode()
        if hasattr(self, "inner"):
            raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore")
        else:
            self.inner = old
            if old is None:
                self.ancestors = set()
            else:
                self.ancestors = self.inner.ancestors.union({self.inner})
        _set_torch_dispatch_mode(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        _set_torch_dispatch_mode(self.inner)

    @contextlib.contextmanager
    def restore(self):
        return _restore_mode(self, mode_info=TorchDispatchModeInfo())

    @classmethod
    def push(cls, *args, **kwargs):
        return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs))
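The ``ancestors`` bookkeeping in ``__enter__`` above can be exercised without torch: each newly entered mode records every mode below it on the stack. The names below (`Node`, `enter`, `stack`) are hypothetical stand-ins for this sketch:

```python
stack = None  # stand-in for the thread-local current mode

class Node:
    def __init__(self):
        self.inner = None
        self.ancestors = set()

def enter(n):
    # Mirrors TorchDispatchMode.__enter__: link to the displaced mode and
    # accumulate the set of every mode beneath this one.
    global stack
    n.inner = stack
    n.ancestors = set() if stack is None else stack.ancestors | {stack}
    stack = n

a, b, c = Node(), Node(), Node()
enter(a); enter(b); enter(c)
assert c.ancestors == {a, b}
assert b.ancestors == {a}
assert a.ancestors == set()
```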
class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)
@contextlib.contextmanager
def push_torch_dispatch_mode(ctor) -> Iterator[object]:
    return _push_mode(ctor, mode_info=TorchDispatchModeInfo())