# Owner(s): ["module: __torch_dispatch__"]
import tempfile

Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760
See https://github.com/pytorch/pytorch/issues/59049
There are some moving parts to this PR; I'll structure this explanation so the straightforward parts go first, followed by the less straightforward parts.
**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser`, which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple-dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing whether Tensors have custom handling. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
**Maintaining the Python dispatch key.** In order to get to the dispatch-to-Python logic, we must tag Tensors that have the `__torch_dispatch__` magic method with the newly added Python dispatch key (separate from PythonFuncTorch, to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging whether a Tensor is participating in Python dispatch. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing whether `__torch_dispatch__` exists with the newly added `check_has_torch_dispatch`.
**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor participates in Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatching to the Python interpreter associated with a TensorImpl.
**torchdeploy compatibility.** The dispatch-to-Python logic cannot be directly registered to the dispatcher, as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we employ a two-phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch", which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.
**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
**Known limitations.**
* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`, otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass preservation behavior)
* We don't ever populate kwargs, even when an argument is kwarg-only
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D29017912
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Pulled By: ezyang
fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
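To make the mechanism concrete: a minimal has-a wrapper in the spirit of the LoggingTensor described above might look like the sketch below. This is illustrative only, written against the present-day API (`EchoTensor` is a made-up name, and `_make_wrapper_subclass` postdates this PR); it is not code from the PR.
```python
import torch
from torch.utils._pytree import tree_map

class EchoTensor(torch.Tensor):
    """Has-a wrapper that prints every op routed through __torch_dispatch__."""

    @staticmethod
    def __new__(cls, elem):
        t = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device)
        t.elem = elem
        return t

    # Per the known limitations above, default __torch_function__
    # subclass handling must be disabled explicitly.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.elem if isinstance(t, EchoTensor) else t

        def wrap(t):
            return EchoTensor(t) if isinstance(t, torch.Tensor) else t

        print(func)  # e.g. aten.add.Tensor
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

# EchoTensor(torch.ones(2)) + 1 prints each aten op it hits.
```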
import torch
from copy import deepcopy
from torch.library import Library, impl, fallthrough_kernel, _scoped_library
from torch.fx.experimental.symbolic_shapes import ShapeEnv

Support registering op returning symint in python (#95240)
Running an operator registered in Python that returns a SymInt will result in the following error:
```
RuntimeError: Unable to cast Python instance of type <class 'torch.SymInt'> to C++ type 'long'
```
Two things interact to trigger the issue:
- We use a boxed kernel here. For a boxed kernel, we need to convert py::object to IValue in `pushPyOutToStack` (torch/csrc/autograd/python_variable.cpp).
- In the schema parsing code (`SchemaTypeParser::parseFakeAndRealType` in torch/csrc/jit/frontend/schema_type_parser.cpp), if a SymInt is found, we register an Int type instead (not sure why we do this) and register SymInt as the real type.
The result is that we convert the SymInt to int in `pushPyOutToStack`, causing the issue.
The fix is to use the real type when we convert py::object to IValue.
BTW, registering the same op using the C++ API does not trigger the issue:
```
TORCH_LIBRARY(clib, m) {
  m.def("sqsum(SymInt a, SymInt b) -> SymInt", [](SymInt a, SymInt b) -> SymInt {
    return a * a + b * b;
  });
}
```
The reason is that the kernel registered in C++ is an unboxed kernel, which does not go through the code path above that converts a py::object to IValue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95240
Approved by: https://github.com/larryliu0820, https://github.com/ezyang
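For reference, the Python-side registration that exercises this path looks roughly like the sketch below (the `my_ns` namespace and the `sqsum` wiring are illustrative, mirroring the C++ example above):
```python
import torch
from torch.library import Library

lib = Library("my_ns", "DEF")  # scratch namespace, illustrative only

lib.define("sqsum(SymInt a, SymInt b) -> SymInt")

def sqsum(a, b):
    return a * a + b * b

lib.impl("sqsum", sqsum, "CompositeExplicitAutograd")

# Plain ints already worked; SymInt return values hit the boxed-kernel
# py::object -> IValue conversion described above.
assert torch.ops.my_ns.sqsum(3, 4) == 25
```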
from torch import SymInt
from torch._subclasses.fake_tensor import FakeTensorMode

Jiterator with Python Registration (#77121)
You can now do a lot of crazy things to redefine the behavior of an operator, and still be fast on CUDA!
Example 1: swapping where's branches
```
code_string = "template <typename T> T inverted_where(bool cond, T a, T b){ return !cond ? a : b; }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::where.self', jitted_fn, "CUDA")
# torch.where is now overridden
```
Example 2: approximate gelu with relu
```
code_string = "template <typename T> T fast_gelu(T a){ return a > 0 ? a : 0;}"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', jitted_fn, "CUDA")
# torch.nn.GELU and torch.nn.functional.gelu are now overridden
```
Example 3: clipping output for numerically unstable kernels
```
code_string = "template <typename T> T clipped_exp(T a){ return a > T(10.0) ? T(22026.4657948) : exp(a); }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::exp', jitted_fn, "CUDA")
# torch.exp(x) and x.exp() are now overridden
```
Example 4: Simulate buggy hardware behaviors
```
code_string = "template <typename T> T buggy_add(T a, T b){ return a + b + T(1); }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::add.Tensor', jitted_fn, "CUDA")
# torch.add(x, y), x + y, and x.add(y) are now overridden
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77121
Approved by: https://github.com/anjali411
from torch.cuda.jiterator import _create_jit_fn
import unittest
from torch.testing._internal.common_utils import *  # noqa: F403

[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch just takes the top mode in the stack, reenables itself (if we aren't at the end of the mode stack), and runs the top mode's torch_{dispatch|function}.
This maintains the invariant that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++; it's Python only), but it allows the user to also see the entire mode stack easily.
Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.
### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like
```python
## PRE-PR UX
def f(mode):
    with mode.restore():  # user needs to understand this restore thing?
        ...

with Mode() as m:
    pass
f(m)
```
Many users were getting errors from forgetting to call `.restore`, or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
    with mode:
        ...
f(Mode())
```
### Technical Details
With the old mode stack, we basically had a linked list, so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a Python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level, and it runs the next mode in the Python list. The modes don't have state on them anymore.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
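A minimal sketch of the post-PR model (the `CountingMode` here is hypothetical, built on the `TorchDispatchMode` base class):
```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class CountingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.count = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.count += 1
        # While this runs, the mode itself is not active; redispatching
        # reaches the next mode down the Python-side stack, or the kernel.
        return func(*args, **(kwargs or {}))

# Modes are now plain context managers: no empty-bodied "instantiation"
# step and no .restore().
with CountingMode() as outer:
    with CountingMode() as inner:
        torch.ones(2) + 1
assert inner.count > 0 and outer.count >= inner.count
```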
from torch.utils._mode_utils import no_dispatch, all_same_mode
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
    log_input, capture_logs, capture_logs_with_logging_tensor_mode

reorder proxy / fake modes so they always run last (#104482)
**Update:** Made a refactor of the original PR. See the original description below; here I'll describe the updates:
(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.
I added a `TorchDispatchModeKey` enum that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).
`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which, if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped, if there was one - that way, if we push that same mode back onto the TLS later, we know where it goes.
`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query whether there are "any modes at all" or "any user modes".
`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).
There were also some mild codegen changes to support the new enum.
(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`
The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py`) now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.
I also had to change both of these modes' enter/exit to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack and save it in `enter_stack`, so that on exit we can reset the state properly.
(3) dispatching logic in `python_arg_parser.cpp`
This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode()      # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass()  # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass()  # try infra subclasses next (skipping user subclasses)
```
Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: if proxy/fake modes run in step (a), they'll return NotImplemented when they see a user subclass, allowing us to redispatch to the user subclass.
How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key.
(4) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends).
----- original description below -----
The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit
Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.
**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**
This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.
I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.
This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.
**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**
This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: first we inspect the mode stack and any non-fake subclass inputs, then we check for the proxy mode slot, then we check for the fake mode slot, before finally checking for any fake subclass inputs.
**(3) New python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` APIs**
Before, if you wanted to see if proxy or fake modes were active in Python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new APIs to directly check if either proxy or fake modes are active.
**(4) Allow traceable tensor subclasses to access storages from python**
This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to Python if the tensor subclass that you are passing in is "traceable".
**(5) Fixed subclass fakeification**
@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.
**(6) Make tensor subclasses resizeable**
Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.
**(7) Added a basic DoubleTensor subclass for testing**
I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
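The user-visible invariant can be sketched as follows (`NoisyMode` is a made-up user mode; `FakeTensorMode` is the infra mode imported near the top of this file):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

class NoisyMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("user mode first:", func)
        return func(*args, **(kwargs or {}))

# FakeTensorMode is entered innermost, but it is parked in its dedicated
# "infra" slot, so the user mode still dispatches first.
with NoisyMode():
    with FakeTensorMode():
        torch.empty(2, 2) + 1
```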
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._pytree import tree_map, tree_map_only
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
from torch._custom_op.functional import register_functional_op
from torch._C import DispatchKeySet, DispatchKey
from torch.fx.experimental.proxy_tensor import make_fx

add return_and_correct_aliasing() util for wrapper subclasses (#107915)
This PR adds a `return_and_correct_aliasing()` utility that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.
Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all opinfos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.
One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):
(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)
(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation
(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)
(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe... this is the right thing to do, though?
Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
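Inside a wrapper subclass's `__torch_dispatch__`, the utility is used roughly as in the sketch below (`MyWrapper` and its `elem` field are placeholders; `TwoTensor`, imported above, is the real in-tree example):
```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torch.utils._pytree import tree_map_only

class MyWrapper(torch.Tensor):
    # __new__ / __tensor_flatten__ elided; assume .elem holds the inner tensor.

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op on the inner tensors, then rewrap the outputs.
        inner_args = tree_map_only(MyWrapper, lambda t: t.elem, args)
        inner_kwargs = tree_map_only(MyWrapper, lambda t: t.elem, kwargs)
        out = tree_map_only(torch.Tensor, MyWrapper, func(*inner_args, **inner_kwargs))
        # Patch up output storage aliasing so views and in-place ops on the
        # wrapper alias the way they would on plain tensors.
        return return_and_correct_aliasing(func, args, kwargs, out)
```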
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.multiprocessing.reductions import StorageWeakRef
import logging
import sys
import torch._dynamo


class TestDispatcherPythonBindings(TestCase):
    def test_call_boxed(self) -> None:
        # Look up the aten::sin schema and call it through the dispatcher's
        # boxed (stack-based) calling convention exposed to Python.
        sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
        x = torch.randn(3)
        y = torch._C._dispatch_call_boxed(sin, x)
        self.assertEqual(y, x.sin())

class TestPythonRegistration(TestCase):
    test_ns = '_test_python_registration'

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:
            with _scoped_library("aten", "IMPL") as my_lib1:
                # Example 1
                def my_neg(*args, **kwargs):
                    return args[0]._neg_view()

                # Now we are secretly making the operator a view op so autograd
                # needs to know how to handle it
                my_lib1.impl('neg', my_neg, "AutogradCPU")

                self.assertTrue(torch.neg(x).is_neg())

                # RuntimeError: impl("aten::neg", ...):
                # Explicitly provided namespace (aten) in operator name does not match ...
                with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
                    with _scoped_library("foo", "DEF") as my_lib3:
                        my_lib3.define("neg(Tensor self) -> Tensor")
                        my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

                # Example 2
                def my_mul(*args, **kwargs):
                    return torch.zeros_like(args[0])

                # torch.ops.aten.mul.Tensor
                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")

                y = torch._efficientzerotensor(2)
                self.assertFalse(torch.mul(x, y)._is_zerotensor())

                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
                # combination if someone already overrode the behavior for it before them
                with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")

            # Validate that lib2 is not affected by removing lib1
            self.assertFalse(torch.mul(x, y)._is_zerotensor())

        # Validate that the old behavior is restored for neg and mul
        self.assertFalse(torch.neg(x).is_neg())
        self.assertTrue(torch.mul(x, y)._is_zerotensor())

    def test_error_if_fn_not_callable(self):
        with self.assertRaisesRegex(TypeError, "Input function is required to be a callable"):
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

    def test_finalizer(self):
        impls_refcnt = sys.getrefcount(torch.library._impls)
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        lib.define("foo123(Tensor x) -> Tensor")

        # 1 for `lib`, 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib), 2)
        # We gained an additional reference that gets cleared when the finalizer runs
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
        # 1 for `lib`
        # 1 for the finalizer
        # 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib._op_impls), 3)

        def foo123(x):
            pass

        lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
        key = f'{self.test_ns}/foo123/CPU'
        self.assertTrue(key in torch.library._impls)

        saved_op_impls = lib._op_impls

        # del will definitely work if the following passes
        self.assertEqual(sys.getrefcount(lib), 2)
        del lib

        # 1 for saved_op_impls
        # 1 for sys.getrefcount
        # This function should be the last user of lib._op_impls:
        # - lib should not have a reference anymore (it was del'ed)
        # - lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(saved_op_impls), 2)

        self.assertTrue(key not in torch.library._impls)

        # lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)

    def test_override_cpu_sum(self) -> None:
        # Example 1
        run = [False]

        def my_sum(*args, **kwargs):
            run[0] = True
            return args[0].clone()

        with _scoped_library("aten", "IMPL") as my_lib1:
            my_lib1.impl('aten::sum', my_sum, "CPU")
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)
            self.assertTrue(run[0])
        # Validate that the old behavior is restored for sum
        self.assertEqual(torch.sum(x), torch.tensor(3))

    def test_override_cuda_with_jiterator(self) -> None:
        def override_where_cuda() -> None:
            # Example 1: Invert the behavior of where's condition input
            not_where_code_string = '''
            template <typename T> T inverted_where(bool cond, T a, T b){
                return !cond ? a : b;
            }
            '''
            jitted_where = _create_jit_fn(not_where_code_string)

            CALLED = [False]

            def inverted_where(*args, **kwargs):
                CALLED[0] = True
                return jitted_where(*args, **kwargs)

            # overriding where's cuda kernel with Jiterator generated kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl('aten::where.self', inverted_where, "CUDA")

                device = 'cuda'
                cond = torch.tensor([True, True, False], device=device, dtype=torch.bool)
                x = torch.tensor([1, 2, 3], device=device)
                y = torch.tensor([-1, -2, -3], device=device)

                self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))

        def override_gelu_cuda() -> None:
            # Example 2: Use relu to approximate gelu for faster compute
            fastest_gelu_code_string = '''
            template <typename T> T fast_gelu(T a){
                return a > 0 ? a : 0;
            }
            '''
            jitted_gelu = _create_jit_fn(fastest_gelu_code_string)

            CALLED = [False]

            def fast_gelu(*args, **kwargs):
                CALLED[0] = True
                return jitted_gelu(*args, **kwargs)

            # overriding gelu's cuda kernel with Jiterator generated relu kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl('aten::gelu', fast_gelu, "CUDA")

                x = torch.rand([3, 3], device='cuda', dtype=torch.float)
                self.assertEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertNotEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))

        def override_exp_cuda() -> None:
            # Example 3: Preventing exp from exploding for float16
            clipped_exp_code_string = '''
            template <typename T> T clipped_exp(T a){
                return a > T(10.0) ? T(22026.4657948) : exp(a);
            }
            '''
            jitted_exp = _create_jit_fn(clipped_exp_code_string)

            CALLED = [False]

            def clipped_exp(*args, **kwargs):
                CALLED[0] = True
                return jitted_exp(*args, **kwargs)

            # overriding exp's cuda kernel with clipped_exp kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl('aten::exp', clipped_exp, "CUDA")

                x = torch.tensor([0.0, 100.0], device='cuda', dtype=torch.float16)
                self.assertEqual(torch.exp(x), torch.tensor([1.0, 22026.4657948], dtype=torch.float16))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16))

        def override_add_cuda() -> None:
            # Example 4: simulate a hardware bug, where the adder is always off by 1
            buggy_add_code_string = '''
            template <typename T> T buggy_add(T a, T b){
                return a + b + T(1);
            }
            '''
            jitted_add = _create_jit_fn(buggy_add_code_string)

            CALLED = [False]

            def buggy_add(*args, **kwargs):
                CALLED[0] = True
                return jitted_add(*args, **kwargs)

            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl('aten::add.Tensor', buggy_add, "CUDA")
x_cpu = torch.rand([3, 3], device='cpu')
|
|
|
|
|
y_cpu = torch.rand([3], device='cpu')
|
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
x_cuda = x_cpu.cuda()
|
|
|
|
|
y_cuda = y_cpu.cuda()
|
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
|
|
|
|
|
self.assertTrue(CALLED[0])
|
|
|
|
|
|
|
|
|
# behavior restored after deregistration
|
|
|
|
|
self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)
|
|
|
|
|
|
|
|
|
|
if torch.cuda.is_available() and not TEST_WITH_ROCM:
|
|
|
|
|
override_where_cuda()
|
|
|
|
|
override_gelu_cuda()
|
|
|
|
|
override_exp_cuda()
|
|
|
|
|
override_add_cuda()
|
|
|
|
|
|
2022-05-04 21:51:09 +00:00
|
|
|
def test_extend_library_with_dispatch_key_arg(self):
|
|
|
|
|
def my_sum(*args, **kwargs):
|
2023-04-07 18:26:35 +00:00
|
|
|
return args[0].clone()
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
|
|
|
|
|
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
|
|
|
|
|
# inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
|
|
|
|
|
my_lib1.impl('sum', my_sum, "Conjugate")
|
|
|
|
|
my_lib1.impl('aten::sum', my_sum)
|
|
|
|
|
x = torch.tensor([1, 2])
|
|
|
|
|
self.assertEqual(torch.sum(x), x)
|
2022-05-04 21:51:09 +00:00
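For reference, a short sketch of the behavior this test relies on: a `Library` constructed with `dispatch_key=` pins every subsequent `impl()` to that key, and passing a conflicting key raises.
```
import torch
from torch.library import Library

my_lib = Library("aten", "IMPL", dispatch_key="CPU")

def my_sum(*args, **kwargs):
    return args[0].clone()

# No key passed here, so the registration lands on the CPU key from the
# constructor; an explicit inconsistent key (e.g. "Conjugate") would raise.
my_lib.impl('aten::sum', my_sum)
```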
|
|
|
|
2022-05-04 21:51:09 +00:00
|
|
|
def test_create_new_library(self) -> None:
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library(self.test_ns, "DEF") as my_lib1:
|
|
|
|
|
my_lib1.define("sum(Tensor self) -> Tensor")
|
2022-05-04 21:51:09 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
# Example 1
|
|
|
|
|
@torch.library.impl(my_lib1, "sum", "CPU")
|
|
|
|
|
def my_sum(*args, **kwargs):
|
2023-04-07 18:26:35 +00:00
|
|
|
return args[0].clone()
|
2022-05-04 21:51:09 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
x = torch.tensor([1, 2])
|
|
|
|
|
op = getattr(torch.ops, self.test_ns).sum
|
|
|
|
|
self.assertEqual(op(x), x)
|
2022-05-04 21:51:09 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library(self.test_ns, "IMPL") as my_lib2:
|
|
|
|
|
# Example 2
|
|
|
|
|
@torch.library.impl(my_lib2, op.default, "ZeroTensor")
|
|
|
|
|
def my_sum_zt(*args, **kwargs):
|
|
|
|
|
if args[0]._is_zerotensor():
|
|
|
|
|
return torch._efficientzerotensor(args[0].shape)
|
|
|
|
|
else:
|
|
|
|
|
return args[0].clone()
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
y = torch._efficientzerotensor(3)
|
|
|
|
|
self.assertTrue(op(y)._is_zerotensor())
|
|
|
|
|
self.assertEqual(op(x), x)
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
def test_create_new_library_fragment_no_existing(self):
|
|
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
|
|
|
|
|
my_lib.define("sum2(Tensor self) -> Tensor")
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
@torch.library.impl(my_lib, "sum2", "CPU")
|
|
|
|
|
def my_sum(*args, **kwargs):
|
|
|
|
|
return args[0]
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
x = torch.tensor([1, 2])
|
|
|
|
|
self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
|
2023-04-07 18:26:35 +00:00
|
|
|
|
|
|
|
|
def test_create_new_library_fragment_with_existing(self):
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library(self.test_ns, "DEF") as my_lib1:
|
|
|
|
|
# Create a fragment
|
|
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
|
|
|
|
|
my_lib2.define("sum4(Tensor self) -> Tensor")
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
@torch.library.impl(my_lib2, "sum4", "CPU")
|
|
|
|
|
def my_sum4(*args, **kwargs):
|
|
|
|
|
return args[0]
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
x = torch.tensor([1, 2])
|
|
|
|
|
self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
# Create another fragment
|
|
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
|
|
|
|
|
my_lib3.define("sum3(Tensor self) -> Tensor")
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
@torch.library.impl(my_lib3, "sum3", "CPU")
|
|
|
|
|
def my_sum3(*args, **kwargs):
|
|
|
|
|
return args[0]
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
x = torch.tensor([1, 2])
|
|
|
|
|
self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
|
2023-04-07 18:26:35 +00:00
|
|
|
|
2022-05-19 17:35:06 +00:00
|
|
|
@unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
|
|
|
|
|
def test_alias_analysis(self):
|
|
|
|
|
def test_helper(alias_analysis=""):
|
2024-02-12 23:30:08 +00:00
|
|
|
my_lib1 = Library(self.test_ns, "DEF") # noqa: TOR901
|
2022-05-19 17:35:06 +00:00
|
|
|
|
|
|
|
|
called = [0]
|
|
|
|
|
|
|
|
|
|
@torch.library.define(my_lib1, "_op() -> None", alias_analysis=alias_analysis)
|
|
|
|
|
def _op(*args, **kwargs):
|
|
|
|
|
called[0] += 1
|
|
|
|
|
|
|
|
|
|
@torch.jit.script
|
|
|
|
|
def _test():
|
2023-06-01 18:06:45 +00:00
|
|
|
torch.ops._test_python_registration._op()
|
2022-05-19 17:35:06 +00:00
|
|
|
|
2023-06-01 18:06:45 +00:00
|
|
|
assert "_test_python_registration::_op" in str(_test.graph)
|
2022-05-19 17:35:06 +00:00
|
|
|
|
|
|
|
|
with self.assertRaises(AssertionError):
|
|
|
|
|
test_helper("") # alias_analysis="FROM_SCHEMA"
|
|
|
|
|
|
|
|
|
|
test_helper("CONSERVATIVE")
|
|
|
|
|
|
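For reference, a standalone sketch of the CONSERVATIVE case (using a hypothetical `mylib` namespace): conservative alias analysis keeps a side-effect-only op alive in a scripted graph, whereas under the default analysis TorchScript is free to optimize it away, which is why the `""` case above trips the assertion.
```
import torch
from torch.library import Library

my_lib = Library("mylib", "DEF")

@torch.library.define(my_lib, "_noop() -> None", alias_analysis="CONSERVATIVE")
def _noop(*args, **kwargs):
    pass

@torch.jit.script
def f():
    torch.ops.mylib._noop()

assert "mylib::_noop" in str(f.graph)  # the op survives scripting
```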
2022-07-05 21:49:29 +00:00
|
|
|
def test_error_for_unsupported_ns_or_kind(self) -> None:
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "Unsupported kind"):
|
2024-02-12 23:30:08 +00:00
|
|
|
my_lib1 = Library("myns", "BLA") # noqa: TOR901
|
2022-07-05 21:49:29 +00:00
|
|
|
|
2023-04-07 18:26:35 +00:00
|
|
|
for kind in ('DEF', 'FRAGMENT'):
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "reserved namespace"):
|
2024-02-12 23:30:08 +00:00
|
|
|
my_lib1 = Library("prim", kind) # noqa: TOR901
|
2022-06-10 03:02:28 +00:00
|
|
|
|
Support registering op returning symint in python (#95240)
Running an operator registered in Python that returns a SymInt used to fail with the following error:
```
RuntimeError: Unable to cast Python instance of type <class 'torch.SymInt'> to C++ type 'long'
```
Two things interact to trigger the issue:
- We use a boxed kernel here. For boxed kernels, we need to convert py::object to IValue in pushPyOutToStack in torch/csrc/autograd/python_variable.cpp.
- In the schema parsing code in SchemaTypeParser::parseFakeAndRealType in torch/csrc/jit/frontend/schema_type_parser.cpp, if a SymInt is found, we register an Int type instead (not sure why we do this) and register SymInt as the real type.
As a result, we would convert the SymInt to an int in pushPyOutToStack, causing the error above.
The fix is to use the real type when converting py::object to IValue.
Note that registering the same op via the C++ API does not trigger the issue:
```
TORCH_LIBRARY(clib, m) {
m.def("sqsum(SymInt a, SymInt b) -> SymInt", [](SymInt a, SymInt b) -> SymInt {
return a * a + b * b;
});
}
```
The reason is that the kernel registered in C++ is an unboxed kernel, so it never hits the code path above that converts a py::object to an IValue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95240
Approved by: https://github.com/larryliu0820, https://github.com/ezyang
2023-02-22 04:56:37 +00:00
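With the fix, the Python-side registration mirrors the C++ one. A minimal sketch under a hypothetical `mylib` namespace (the test below exercises the same path with genuinely symbolic ints from a ShapeEnv):
```
import torch
from torch.library import Library, impl

tlib = Library("mylib", "DEF")
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

@impl(tlib, "sqsum", "CompositeExplicitAutograd")
def sqsum(a, b):
    return a * a + b * b

# Plain ints are valid SymInts, so this returns 13; symbolic inputs
# (e.g. sizes traced under FakeTensorMode) yield a SymInt expression.
print(torch.ops.mylib.sqsum.default(2, 3))
```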
|
|
|
def test_returning_symint(self) -> None:
|
|
|
|
|
shape_env = ShapeEnv()
|
|
|
|
|
fake_tensor_mode = FakeTensorMode(shape_env=shape_env)
|
|
|
|
|
|
|
|
|
|
ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))
|
|
|
|
|
|
|
|
|
|
s0, s1 = ft.shape
|
|
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library(self.test_ns, "DEF") as tlib:
|
|
|
|
|
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
|
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
@impl(tlib, "sqsum", "CompositeExplicitAutograd")
|
|
|
|
|
def sqsum(a: SymInt, b: SymInt):
|
|
|
|
|
return a * a + b * b
|
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
|
|
|
|
|
out_val = shape_env.evaluate_expr(out.node.expr)
|
2023-07-19 14:40:18 +00:00
|
|
|
self.assertEqual(out_val, 13)
|
|
|
|
|
2023-06-01 18:44:57 +00:00
|
|
|
def test_register_functional_op_error_cases(self):
|
2024-01-26 19:08:49 +00:00
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
|
|
|
|
with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
|
|
|
|
|
register_functional_op(lib, "abs", torch.ops.aten.abs_)
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
|
|
|
|
|
register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
|
|
|
|
|
register_functional_op(lib, "abs", torch.ops.aten.abs.out)
|
|
|
|
|
|
|
|
|
|
schemas = [
|
|
|
|
|
'foo(Tensor x, Tensor(a!)[] y) -> ()',
|
|
|
|
|
'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
|
|
|
|
|
'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
|
|
|
|
|
]
|
2023-06-01 18:44:57 +00:00
|
|
|
|
|
|
|
|
for schema in schemas:
|
2024-01-26 19:08:49 +00:00
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
2023-06-01 18:44:57 +00:00
|
|
|
lib.define(schema)
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "NYI"):
|
|
|
|
|
register_functional_op(
|
|
|
|
|
lib,
|
|
|
|
|
"foo_functional",
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default)
|
|
|
|
|
|
|
|
|
|
def _check_is_functional_variant(self, mutable_op, functional_op, args):
|
|
|
|
|
# functional op should not mutate
|
|
|
|
|
cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
|
|
|
|
|
functional_result = functional_op(*cloned_args)
|
|
|
|
|
self.assertEqual(cloned_args, args)
|
|
|
|
|
|
|
|
|
|
# check functional_result includes mutable_result
|
|
|
|
|
mutable_result = mutable_op(*cloned_args)
|
|
|
|
|
if mutable_result is None:
|
|
|
|
|
flat_mutable_result = []
|
|
|
|
|
else:
|
2023-10-30 00:05:29 +00:00
|
|
|
flat_mutable_result = pytree.tree_leaves(mutable_result)
|
|
|
|
|
flat_functional_result = pytree.tree_leaves(functional_result)
|
2023-06-01 18:44:57 +00:00
|
|
|
assert len(flat_functional_result) > len(flat_mutable_result)
|
|
|
|
|
self.assertEqual(flat_functional_result[:len(flat_mutable_result)], flat_mutable_result)
|
|
|
|
|
|
|
|
|
|
# check rest of functional_result is the mutated args
|
|
|
|
|
mutated_args = [maybe_mutated_arg for maybe_mutated_arg, arg in zip(cloned_args, args)
|
2024-01-03 06:04:44 +00:00
|
|
|
if not (maybe_mutated_arg is not None and arg is not None and torch.allclose(maybe_mutated_arg, arg))]
|
2023-06-01 18:44:57 +00:00
|
|
|
self.assertEqual(flat_functional_result[len(flat_mutable_result):], mutated_args)
|
|
|
|
|
|
|
|
|
|
# check that functionalization kernel was indeed registered
|
|
|
|
|
def fn(*args):
|
|
|
|
|
cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
|
|
|
|
|
mutable_op(*cloned_args)
|
|
|
|
|
return cloned_args
|
|
|
|
|
|
|
|
|
|
gm = make_fx(torch.func.functionalize(fn))(*args)
|
|
|
|
|
has_functional_op = False
|
|
|
|
|
for node in gm.graph.nodes:
|
|
|
|
|
self.assertFalse(node.target is mutable_op)
|
|
|
|
|
if node.target is functional_op:
|
|
|
|
|
has_functional_op = True
|
|
|
|
|
self.assertTrue(has_functional_op)
|
|
|
|
|
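The "should not mutate" check above is just a clone-and-compare over the argument pytree; a condensed, self-contained version of that idiom:
```
import torch
import torch.utils._pytree as pytree

def assert_does_not_mutate(op, *args):
    cloned = pytree.tree_map_only(torch.Tensor, torch.clone, args)
    op(*cloned)
    # every tensor leaf must be unchanged by the call
    for before, after in zip(pytree.tree_leaves(args), pytree.tree_leaves(cloned)):
        assert torch.equal(before, after)

assert_does_not_mutate(torch.abs, torch.tensor([-1.0, 2.0]))    # passes
# assert_does_not_mutate(torch.abs_, torch.tensor([-1.0, 2.0])) # would raise
```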
|
|
|
|
|
def test_register_functional_op_no_returns(self):
|
2024-01-26 19:08:49 +00:00
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
|
|
|
|
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
|
|
|
|
|
|
|
|
|
|
def foo_impl(x, y, z, w):
|
|
|
|
|
y.fill_(3.14)
|
|
|
|
|
w.fill_(2.71)
|
|
|
|
|
|
|
|
|
|
lib.impl('foo', foo_impl, 'CPU')
|
|
|
|
|
register_functional_op(
|
|
|
|
|
lib,
|
|
|
|
|
'foo_functional',
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default)
|
|
|
|
|
x = torch.randn([])
|
|
|
|
|
y = torch.randn([])
|
|
|
|
|
z = torch.randn([])
|
|
|
|
|
w = torch.randn([])
|
|
|
|
|
self._check_is_functional_variant(
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default,
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
2023-06-01 18:44:57 +00:00
|
|
|
|
2023-11-30 23:48:03 +00:00
|
|
|
def test_register_functional_op_with_optional(self):
|
2024-01-26 19:08:49 +00:00
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
|
|
|
|
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(b!) z, Tensor(c!)? w) -> ()')
|
|
|
|
|
|
|
|
|
|
def foo_impl(x, y, z, w):
|
|
|
|
|
y.fill_(3.14)
|
|
|
|
|
z.fill_(2.71)
|
|
|
|
|
if w is not None:
|
|
|
|
|
w.fill_(1.618)
|
|
|
|
|
|
|
|
|
|
lib.impl('foo', foo_impl, 'CPU')
|
|
|
|
|
register_functional_op(
|
|
|
|
|
lib,
|
|
|
|
|
'foo_functional',
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default)
|
|
|
|
|
x = torch.randn([])
|
|
|
|
|
y = torch.randn([])
|
|
|
|
|
z = torch.randn([])
|
|
|
|
|
w = torch.randn([])
|
|
|
|
|
self._check_is_functional_variant(
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default,
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
|
|
|
|
self._check_is_functional_variant(
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default,
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
|
2023-11-30 23:48:03 +00:00
|
|
|
|
2023-06-01 18:44:57 +00:00
|
|
|
def test_register_functional_op_one_return(self):
|
2024-01-26 19:08:49 +00:00
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
|
|
|
|
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
|
|
|
|
|
|
|
|
|
|
def foo_impl(x, y, z, w):
|
|
|
|
|
y.fill_(3.14)
|
|
|
|
|
w.fill_(2.71)
|
|
|
|
|
z.fill_(0.99)
|
|
|
|
|
return x.clone()
|
|
|
|
|
|
|
|
|
|
lib.impl('foo', foo_impl, 'CPU')
|
|
|
|
|
register_functional_op(
|
|
|
|
|
lib,
|
|
|
|
|
"foo_functional",
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default)
|
|
|
|
|
x = torch.randn([])
|
|
|
|
|
y = torch.randn([])
|
|
|
|
|
z = torch.randn([])
|
|
|
|
|
w = torch.randn([])
|
|
|
|
|
self._check_is_functional_variant(
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default,
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
2023-06-01 18:44:57 +00:00
|
|
|
|
|
|
|
|
def test_register_functional_op_multiple_returns(self):
|
2024-01-26 19:08:49 +00:00
|
|
|
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
|
|
|
|
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
|
|
|
|
|
|
|
|
|
|
def foo_impl(x, y, z, w):
|
|
|
|
|
y.fill_(3.14)
|
|
|
|
|
w.fill_(2.71)
|
|
|
|
|
return x.clone(), z.clone()
|
|
|
|
|
|
|
|
|
|
lib.impl('foo', foo_impl, 'CPU')
|
|
|
|
|
register_functional_op(
|
|
|
|
|
lib,
|
|
|
|
|
'foo_functional',
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default)
|
|
|
|
|
|
|
|
|
|
x = torch.randn([])
|
|
|
|
|
y = torch.randn([])
|
|
|
|
|
z = torch.randn([])
|
|
|
|
|
w = torch.randn([])
|
|
|
|
|
self._check_is_functional_variant(
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo.default,
|
|
|
|
|
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
2023-06-01 18:44:57 +00:00
|
|
|
|
2023-07-28 16:12:07 +00:00
|
|
|
def test_register_fallthrough(self):
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library('aten', 'IMPL') as my_lib:
|
2023-07-28 16:12:07 +00:00
|
|
|
my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
|
|
|
|
|
|
|
|
|
|
a = torch.randn(2, 3, device='cpu', dtype=torch.float32)
|
|
|
|
|
b = torch.randn(3, 2, device='cpu', dtype=torch.float32)
|
|
|
|
|
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
|
|
|
|
# dtype for mm should be float32 since we registered a fallthrough
|
|
|
|
|
self.assertEqual(torch.mm(a, b).dtype, torch.float32)
|
|
|
|
|
# ops that don't have a fallthrough registered should not be affected
|
|
|
|
|
self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
|
|
|
|
# default behavior should have been restored
|
|
|
|
|
self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)
|
|
|
|
|
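Outside the test harness, the same registration reads as follows (a sketch; it assumes `fallthrough_kernel` is imported from `torch.library`, where it lives as a sentinel):
```
import torch
from torch.library import Library, fallthrough_kernel

my_lib = Library("aten", "IMPL")
my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")  # mm skips autocast

a = torch.randn(2, 3)
b = torch.randn(3, 2)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    assert torch.mm(a, b).dtype == torch.float32       # fell through
    assert torch.matmul(a, b).dtype == torch.bfloat16  # still autocast
```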
|
Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760
See https://github.com/pytorch/pytorch/issues/59049
There are some moving parts to this PR, I'll structure this explanation so the straightforward parts go first, and then the less straightforward parts.
**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser` which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing if Tensors have custom handling or not. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.
**Maintaining the Python dispatch key.** In order to get to the dispatch to Python logic, we must tag Tensors with the `__torch_dispatch__` magic method with the newly added Python dispatch key (separated from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging if a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing if `__torch_dispatch__` exists with then newly added `check_has_torch_dispatch`.
**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor is Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatching to the Python interpreter associated with a TensorImpl.
**torchdeploy compatibility.** The dispatch to Python logic cannot be directly registered to the dispatcher as it is compiled in the Python library, which will get loaded multiple times per torchdeploy interpreter. Thus, we must employ a two phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch" which will go to the correct torchdeploy interpreter to handle dispatching to actual Python.
**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
**Known limitations.**
* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`; otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass preservation behavior).
* We don't ever populate kwargs, even when an argument is kwarg-only
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D29017912
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Pulled By: ezyang
fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 18:49:20 +00:00
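Before the tests themselves, a minimal sketch of a `__torch_dispatch__` wrapper subclass in the spirit of the LoggingTensor used below (with `__torch_function__` disabled, per the limitations above); all names here are illustrative:
```
import torch

class EchoTensor(torch.Tensor):
    # has-a wrapper: the real data lives in elem
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"dispatch: {func}")
        unwrap = lambda t: t.elem if isinstance(t, EchoTensor) else t
        args = tuple(unwrap(a) for a in args)
        return EchoTensor(func(*args, **(kwargs or {})))

x = EchoTensor(torch.ones(2))
y = x + x  # prints something like: dispatch: aten.add.Tensor
```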
|
|
|
class TestPythonDispatch(TestCase):
|
|
|
|
|
def test_basic(self) -> None:
|
|
|
|
|
with capture_logs() as logs:
|
2022-02-14 20:05:41 +00:00
|
|
|
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
|
|
|
|
log_input("x", x)
|
|
|
|
|
y = x * x
|
|
|
|
|
saved_x = y.grad_fn._saved_self
|
|
|
|
|
grad_y = LoggingTensor(torch.tensor([1.0]))
|
|
|
|
|
log_input("grad_y", grad_y)
|
2021-08-12 18:39:31 +00:00
|
|
|
g, = torch.autograd.grad((y,), (x,), (grad_y,))
|
|
|
|
|
|
|
|
|
self.assertEqual(g.elem, torch.tensor([6.0]))
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
self.assertEqual(saved_x, x)
|
|
|
|
|
self.assertEqual(saved_x._version, x._version)
|
|
|
|
|
x.add_(2)
|
|
|
|
|
self.assertEqual(saved_x, x)
|
|
|
|
|
# TODO: figure out why broken
|
|
|
|
|
# self.assertEqual(saved_x._version, x._version)
|
2021-08-12 18:39:31 +00:00
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[1] = input('x')
|
|
|
|
|
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
|
|
|
|
|
$2: f32[1] = input('grad_y')
|
|
|
|
|
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
|
|
|
|
|
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
|
|
|
|
|
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)''')
|
|
|
|
|
|
|
|
|
def test_out(self) -> None:
|
|
|
|
|
with capture_logs() as logs:
|
|
|
|
|
x = LoggingTensor(torch.ones(1))
|
|
|
|
|
y = LoggingTensor(torch.zeros(1))
|
|
|
|
|
log_input("x", x)
|
|
|
|
|
log_input("y", y)
|
|
|
|
|
torch.abs(x, out=y)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(y.elem, torch.ones(1))
|
|
|
|
|
# TODO: arguably this shouldn't pass and we should complain
|
|
|
|
|
# that out isn't a kwarg
|
2021-08-12 18:39:31 +00:00
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[1] = input('x')
|
|
|
|
|
$1: f32[1] = input('y')
|
|
|
|
|
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)''')
|
2021-08-12 18:39:31 +00:00
|
|
|
|
2021-08-09 16:59:01 +00:00
|
|
|
def test_kwarg_only(self) -> None:
|
|
|
|
|
with capture_logs() as logs:
|
|
|
|
|
x = LoggingTensor(torch.ones(1))
|
|
|
|
|
y = LoggingTensor(torch.ones(1, 1))
|
|
|
|
|
z = LoggingTensor(torch.ones(1))
|
|
|
|
|
log_input("x", x)
|
|
|
|
|
log_input("y", y)
|
|
|
|
|
log_input("z", z)
|
|
|
|
|
torch.addmv(x, y, z)
|
|
|
|
|
torch.addmv(x, y, z, beta=1)
|
|
|
|
|
torch.addmv(x, y, z, beta=2)
|
|
|
|
|
torch.addmv(x, y, z, alpha=2)
|
|
|
|
|
torch.addmv(x, y, z, beta=2, alpha=2)
|
|
|
|
|
|
|
|
|
|
# The expectation is that beta/alpha don't show up when they're
|
|
|
|
|
# defaulted. This is even if the user explicitly specified it.
|
2021-08-12 18:39:31 +00:00
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[1] = input('x')
|
|
|
|
|
$1: f32[1, 1] = input('y')
|
|
|
|
|
$2: f32[1] = input('z')
|
|
|
|
|
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
|
|
|
|
|
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
|
|
|
|
|
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
|
|
|
|
|
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
|
|
|
|
|
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
|
2021-08-09 16:59:01 +00:00
|
|
|
|
|
|
|
|
def test_kwarg_only_and_positional_default(self) -> None:
|
|
|
|
|
with capture_logs() as logs:
|
|
|
|
|
x = LoggingTensor(torch.ones(1))
|
|
|
|
|
log_input("x", x)
|
2022-07-13 18:38:12 +00:00
|
|
|
torch.ops.aten._foobar(x)
|
|
|
|
|
torch.ops.aten._foobar(x, False)
|
|
|
|
|
torch.ops.aten._foobar(x, arg3=False)
|
|
|
|
|
torch.ops.aten._foobar(x, False, arg3=False)
|
2021-08-09 16:59:01 +00:00
|
|
|
|
2022-07-13 18:38:12 +00:00
|
|
|
# What we are testing here is that we omit arg2
|
2021-08-09 16:59:01 +00:00
|
|
|
# if it is defaulted, even if a kwarg is set
|
2021-08-12 18:39:31 +00:00
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[1] = input('x')
|
|
|
|
|
$1: f32[1] = torch._ops.aten._foobar.default($0)
|
|
|
|
|
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
|
|
|
|
|
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
|
|
|
|
|
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)''')
|
|
|
|
|
2022-05-16 16:38:52 +00:00
|
|
|
def test_produce_real_type(self) -> None:
|
|
|
|
|
with capture_logs() as logs:
|
|
|
|
|
x = LoggingTensor(torch.ones(2, 2))
|
|
|
|
|
log_input("x", x)
|
|
|
|
|
x.to(dtype=torch.double) # non-optional dtype
|
|
|
|
|
torch.cumprod(x, 0, dtype=torch.double) # optional dtype
|
|
|
|
|
x[:, 1].contiguous(memory_format=torch.contiguous_format) # optional memory format
|
|
|
|
|
# There doesn't appear to be any layout signatures which are
|
|
|
|
|
# triggerable using tensor subclasses (need to use a mode)
|
|
|
|
|
|
|
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[2, 2] = input('x')
|
|
|
|
|
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
|
|
|
|
|
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
|
|
|
|
|
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
|
|
|
|
|
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
|
|
|
|
|
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''')
|
2022-05-16 16:38:52 +00:00
|
|
|
|
2022-11-11 14:33:41 +00:00
|
|
|
def test_optional_tensor_list(self) -> None:
|
|
|
|
|
def weird(xs):
|
2023-01-14 07:32:39 +00:00
|
|
|
print("woof")
|
2022-11-11 14:33:41 +00:00
|
|
|
return torch.empty(())
|
|
|
|
|
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library("my_lib", "DEF") as my_lib:
|
|
|
|
|
my_lib.define("weird(Tensor?[] self) -> Tensor")
|
|
|
|
|
my_lib.impl("weird", weird, "CPU")
|
|
|
|
|
with capture_logs() as logs:
|
|
|
|
|
x = LoggingTensor(torch.ones(2, 2))
|
|
|
|
|
log_input("x", x)
|
|
|
|
|
torch.ops.my_lib.weird.default([None, x])
|
2022-11-11 14:33:41 +00:00
|
|
|
|
|
|
|
|
self.assertExpectedInline('\n'.join(logs), '''\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[2, 2] = input('x')
|
|
|
|
|
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
|
2022-11-11 14:33:41 +00:00
|
|
|
|

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):
            class A(torch._C.TensorBase):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func.overloadpacket == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2)
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets a reasonable error message
        class A(torch._C.TensorBase):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegex(
            RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline('\n'.join(logs), '''\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)''')

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash. Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        storage = x.untyped_storage()
        self.assertRaises(RuntimeError, lambda: storage.data_ptr())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (8TB) and this should pass because wrapper
        # subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))))
        self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))))

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x ** 2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                x, = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline('\n'.join(logs), '''\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
def test_subclass_creation(self):
    # Make sure these statements run without error
    # In particular checking that when internal detach returns
    # subclasses, these are cleanly overwritten.
    class Foo(torch.Tensor):
        pass

    err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
    with self.assertRaisesRegex(RuntimeError, err_msg):
        a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
    with self.assertRaisesRegex(RuntimeError, err_msg):
        b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
    with self.assertRaisesRegex(RuntimeError, err_msg):
        Foo(LoggingTensor(torch.rand(2)))

    with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
        torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

def test_new_ones(self) -> None:
    class MyTensor(torch.Tensor):
        __torch_function__ = torch._C._disabled_torch_function_impl

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            return MyTensor(3)

    self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

def test_like(self) -> None:
    class MyTensor(torch.Tensor):
        __torch_function__ = torch._C._disabled_torch_function_impl

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            return MyTensor(3)

    for f in ["empty", "ones", "rand", "randn", "zeros"]:
        f_name = f + "_like"
        self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)

    self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
    self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

reorder proxy / fake modes so they always run last (#104482)
**Update:** I refactored the original PR. See the original description below; here I'll describe the updates:
(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.
I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).
`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes.
`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes".
`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).
There were also some mild codegen changes to support the new enum (a rough Python model of the new TLS layout follows).
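This is a hedged sketch of the described semantics in plain Python, not the C++ implementation; the method names follow the prose:
```python
from enum import Enum
from typing import List, Optional

class TorchDispatchModeKey(Enum):
    PROXY = 0
    FAKE = 1  # FUNCTIONAL comes later in the stack

class ModeTLS:
    def __init__(self):
        self.stack: List[object] = []                      # user modes
        self.infra: List[Optional[object]] = [None, None]  # indexed by mode key

    def push_onto_stack(self, mode, mode_key=None):
        # A mode_key routes the mode into its dedicated infra slot.
        if mode_key is not None:
            assert self.infra[mode_key.value] is None
            self.infra[mode_key.value] = mode
        else:
            self.stack.append(mode)

    def pop_stack(self):
        # User modes pop first; infra modes only once the user stack is empty.
        # Also report which slot (if any) the popped mode came from, so it can
        # be pushed back to the right place later.
        if self.stack:
            return self.stack.pop(), None
        for key in reversed(list(TorchDispatchModeKey)):
            if self.infra[key.value] is not None:
                mode, self.infra[key.value] = self.infra[key.value], None
                return mode, key
        raise RuntimeError("pop from empty mode stack")

    def dispatch_mode_enabled(self, skip_infra_modes=False):
        return bool(self.stack) or (
            not skip_infra_modes and any(m is not None for m in self.infra)
        )
```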
(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`
The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py`) now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.
I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly.
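Continuing the sketch, the enter/exit bookkeeping described above might look like this (again hedged: `_mode_key` and `enter_stack` come from the prose, while `InfraModeMixin` and `_tls` are illustrative names):
```python
_tls = ModeTLS()  # reusing the ModeTLS sketch above

class InfraModeMixin:
    _mode_key = TorchDispatchModeKey.FAKE  # e.g. on a FakeTensorMode analog

    def __init__(self):
        self.enter_stack = []  # ambient mode displaced by each __enter__

    def __enter__(self):
        # Only one fake (or proxy) mode may occupy its slot at a time, so
        # stash whatever currently sits there before installing ourselves.
        prev = _tls.infra[self._mode_key.value]
        _tls.infra[self._mode_key.value] = None
        self.enter_stack.append(prev)
        _tls.push_onto_stack(self, mode_key=self._mode_key)
        return self

    def __exit__(self, *exc):
        # Remove ourselves and restore the displaced ambient mode, if any.
        _tls.infra[self._mode_key.value] = self.enter_stack.pop()
```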
(3) dispatching logic in `python_arg_parser.cpp`
This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses)
```
Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this works today: if the proxy/fake modes run in step (a) and see a user subclass, they return NotImplemented, allowing us to redispatch to the user subclass.
How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key.
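Putting (a)-(c) together, here is a hedged Python-level model of the full ordering (the real logic lives in `python_arg_parser.cpp`; the function and its signature are simplified illustrations):
```python
def model_dispatch(func, args, user_modes, infra_modes, subclass_args):
    def is_infra(t):
        # Infra subclasses (e.g. FakeTensor) carry a hidden _mode_key.
        return getattr(type(t), "_mode_key", None) is not None

    # (a) modes first; the stack already orders infra modes after user modes
    for mode in user_modes + infra_modes:
        r = mode.__torch_dispatch__(func, (), args, {})
        if r is not NotImplemented:
            return r
    # (b) user subclasses next, then (c) infra subclasses
    for group in ([t for t in subclass_args if not is_infra(t)],
                  [t for t in subclass_args if is_infra(t)]):
        for t in group:
            r = type(t).__torch_dispatch__(func, (type(t),), args, {})
            if r is not NotImplemented:
                return r
    raise RuntimeError(f"no __torch_dispatch__ handled {func}")
```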
(4) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that `DoubleTensor` would be easily confused with `torch.FloatTensor` and friends).
----- original description below -----
The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit
Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.
**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**
This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.
I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.
This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.
**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**
This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs.
**(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` APIs**
Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new APIs to directly check if either proxy or fake modes are active.
**(4) Allow traceable tensor subclasses to access storages from python**
This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" (see the hedged example below).
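A hedged example of what this enables (assuming the `TwoTensor` helper used in the tests below, which implements `__tensor_flatten__` and therefore counts as traceable; the import path is my assumption):
```python
import torch
from torch.testing._internal.two_tensor import TwoTensor  # assumed location

t = TwoTensor(torch.randn(2), torch.randn(2))
storage = t.untyped_storage()  # previously raised for tensor subclasses
```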
**(5) Fixed subclass fakeification**
@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.
**(6) make tensor subclasses resizeable**
Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.
**(7) Added a basic DoubleTensor subclass for testing**
I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00

def test_make_fx_with_subclass(self) -> None:
    def f(x, y):
        # Returns (TwoTensor, Tensor)
        return x * y, y + y

    x_a = torch.zeros(4)
    x_b = torch.zeros(4)
    y = torch.ones(4)

    # make_fx() is not responsible for unwrapping tensor subclass inputs,
    # so we do it manually here.
    # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
    # convention as f(*args). Unwrapping tensor subclass inputs can potentially change
    # the number of input args to the graph, breaking that assumption
    def f_to_trace(x_a, x_b, y):
        x = TwoTensor(x_a, x_b)
        out1, out2 = f(x, y)
        out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
        return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)

    fx_g = make_fx(f_to_trace, tracing_mode='fake')(x_a, x_b, y)
    self.assertExpectedInline(fx_g.code, """\



def forward(self, x_a_1, x_b_1, y_1):
    mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None
    add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None
    return (mul, mul_1, add)
    """)

# See https://github.com/pytorch/pytorch/issues/117794
def test_return_and_correct_aliasing_gives_correct_stride(self):
    t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
    x = torch.randn(2, 2)
    # slicing should result in the same stride for TwoTensor as a dense tensor would give
    self.assertEqual(t[:, 0].stride(), x[:, 0].stride())

def test_make_wrapper_subclass_propagates_metadata(self) -> None:
    class WrapperTensor(torch.Tensor):
        elem: torch.Tensor

        __slots__ = ['elem']

        @staticmethod
        def __new__(cls, elem, *args, **kwargs):
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad,
                strides=elem.stride(), storage_offset=elem.storage_offset())
            r.elem = elem
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            raise RuntimeError("NYI")

    # non-contiguous strides, non-zero storage offset
    x = torch.randn(4, 6).t().diagonal(offset=2)
    y = WrapperTensor(x)
    self.assertEqual(y.size(), x.size())
    self.assertEqual(y.stride(), x.stride())
    self.assertEqual(y.storage_offset(), x.storage_offset())

def test_wrapper_subclass_serializes(self) -> None:
    with tempfile.TemporaryFile() as f:
        x = LoggingTensor(torch.randn(3))
        torch.save(x, f)
        f.seek(0)
        x_loaded = torch.load(f)
        self.assertTrue(type(x_loaded) is type(x))
        self.assertEqual(x.elem, x_loaded.elem)
        self.assertFalse(x is x_loaded)

def test_deepcopy_wrapper_subclass(self) -> None:
    x = LoggingTensor(torch.randn(3))
    x_copy = deepcopy(x)
    self.assertTrue(type(x_copy) is type(x))
    self.assertEqual(x.elem, x_copy.elem)
    self.assertFalse(x is x_copy)

def test_deepcopy_wrapper_subclass_with_clone_returning_different_type(self) -> None:

    class MyWrapperTensor(torch.Tensor):
        elem: torch.Tensor

        __slots__ = ['elem']

        @staticmethod
        def __new__(cls, elem, *args, **kwargs):
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad,
                strides=elem.stride(), storage_offset=elem.storage_offset())
            r.elem = elem
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            if func.overloadpacket.__name__ == "clone":
                # Return a plain tensor from clone().
                return args[0].elem.clone()
            raise RuntimeError("NYI")

        # NB: The default Tensor.__torch_function__ implementation called for deepcopy
        # disables __torch_function__ by the time we get to clone(), so there is no need to
        # explicitly disable __torch_function__ for this subclass.

    x = MyWrapperTensor(torch.randn(3))
    with self.assertRaisesRegex(RuntimeError,
                                "for which cloning returns another instance of the same subclass"):
        x_copy = deepcopy(x)

def test_deepcopy_non_wrapper_subclass(self) -> None:

    # Ensure correct error is thrown for common error cases.
    class SubTensorError1(torch.Tensor):
        # Default implementation of new_empty() returns a plain tensor.
        pass

    class SubTensorError2(torch.Tensor):
        # new_empty() incorrectly returns a different type (i.e. a plain tensor).
        def new_empty(self, shape):
            return torch.Tensor(shape)

    for error_cls in [SubTensorError1, SubTensorError2]:
        x = error_cls(3)
        with self.assertRaisesRegex(RuntimeError,
                                    "for which that function returns another instance of the same subclass"):
            x_copy = deepcopy(x)

    # Ensure a correctly implemented new_empty() causes deepcopy() to work.
    class SubTensorSuccess(torch.Tensor):
        def new_empty(self, shape):
            return type(self)(shape)

    x = SubTensorSuccess(3)
    x_copy = deepcopy(x)
    self.assertIs(type(x_copy), type(x))

def test_wrapper_subclass_extra_dispatch_keys(self) -> None:
    class ExtraKeysTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, elem, *args, **kwargs):
            # NB: only the non-kwarg overload of _make_wrapper_subclass supports
            # extra dispatch keys. We probably want to unify the two APIs
            # in the future.
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(), elem.stride(), elem.storage_offset(),
                torch.contiguous_format,
                elem.dtype, elem.layout,
                elem.device, False, False, None, False, False,
                DispatchKeySet(DispatchKey.NestedTensor))
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            pass

    x = ExtraKeysTensor(torch.randn(3))
    self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor))
    self.assertFalse(torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor))

def test_index_put_where_only_index_is_subclass(self) -> None:
    called_funcs = []

    class MyTensor(torch.Tensor):
        __torch_function__ = torch._C._disabled_torch_function_impl
        elem: torch.Tensor
        __slots__ = ['elem']

        @staticmethod
        def __new__(cls, elem, *args, **kwargs):
            r = torch.Tensor._make_wrapper_subclass(
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad
            )
            r.elem = elem
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            called_funcs.append(func)
            return MyTensor(torch.tensor(3))

    x = torch.randn(3, 3)
    idxs = (MyTensor(torch.tensor(0)),)
    v = torch.randn(1)
    res = x.index_put_(idxs, v)
    self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])

[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch just takes the top mode in the stack, reenables itself (if we aren't at the end of the mode stack), and runs that mode's torch_{dispatch|function}.
This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++; it's Python only), but it also lets the user see the entire mode stack easily.
Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.
### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like
```python
## PRE-PR UX
def f(mode):
with mode.restore(): # user needs to understand this restore thing?
...
with Mode() as m:
pass
f(m)
```
Many users were getting errors from forgetting to call `.restore`, or from forgetting the (admittedly weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
with mode:
...
f(Mode())
```
### Technical Details
With the old mode stack, we basically had a linked list, so a mode could only be used once and had a fixed parent. In this new design, the mode stack is just a Python list that we push to and pop from. There's only one mode that's ever active at the C++ level, and it runs the next mode in the Python list. The modes don't have state on them anymore.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
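A hedged sketch of the "one true mode" design in plain Python (`_mode_stack`, `Mode`, and `_dispatch` are illustrative names; the real always-active mode lives at the C++ level):
```python
_mode_stack = []  # a plain Python list, pushed/popped by __enter__/__exit__

class Mode:
    def __enter__(self):
        _mode_stack.append(self)
        return self

    def __exit__(self, *exc):
        _mode_stack.pop()

    def torch_dispatch(self, func, args):
        raise NotImplementedError

def _dispatch(func, *args):
    # The single always-active mode: pop the top Python mode, run its
    # handler with itself off the stack (so a mode is never active inside
    # its own torch_dispatch), then push it back.
    if not _mode_stack:
        return func(*args)
    top = _mode_stack.pop()
    try:
        return top.torch_dispatch(func, args)
    finally:
        _mode_stack.append(top)
```
If `top.torch_dispatch` calls back into `_dispatch`, it sees the next mode down, which is the reentrancy behavior described above.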
def test_torch_dispatch_mode_basic(self) -> None:
    with capture_logs(is_mode=True) as logs:
        with LoggingTensorMode():
            torch.empty([])
    self.assertExpectedInline('\n'.join(logs), """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")

[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```python
with enable_python_mode(LoggingTensor):
z = torch.empty([])
assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_function_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_function_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
    x = torch.randn([])
    y = torch.randn([])
    with capture_logs(is_mode=True) as logs:
        with LoggingTensorMode():
            x + y
    self.assertExpectedInline('\n'.join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)""")

def test_nested_push_logging_tensor_mode(self):
    x = torch.randn([])
    y = torch.randn([])
    with capture_logs(is_mode=True) as logs:
        with LoggingTensorMode():
            with LoggingTensorMode():
                torch.empty([])
                x + y

    self.assertExpectedInline('\n'.join(logs), """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")

def test_capture_logs_with_torch_dispatch_mode(self):
    x = torch.randn([])
    y = torch.randn([])
    with capture_logs_with_logging_tensor_mode() as logs:
        torch.empty([])
        x + y
    self.assertExpectedInline('\n'.join(logs), """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")

    x = torch.randn([])
    y = torch.randn([])

    with capture_logs_with_logging_tensor_mode() as logs1:
        with capture_logs_with_logging_tensor_mode() as logs2:
            torch.empty([])
            x + y

    self.assertExpectedInline('\n'.join(logs2), """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")

    self.assertEqual(logs1, logs2)
def test_torch_dispatch_mode_subclass_priority(self) -> None:
|
[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
z = torch.empty([])
assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
|
|
|
class ErrorA(RuntimeError):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
class ErrorB(RuntimeError):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
class A(torch.Tensor):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def __new__(cls, elem):
|
|
|
|
|
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}
This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily
Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup
### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like
```python
## PRE-PR UX
def f(mode):
with mode.restore(): # user needs to understand this restore thing?
...
with Mode() as m:
pass
f(m)
```
Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
with mode:
...
f(Mode())
```
** Technical Details **
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
|
|
|
with AMode():
|
|
|
|
|
raise ErrorA
|
[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
z = torch.empty([])
assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
|
|
|
|
|
|
|
|
class B(A):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def __new__(cls, elem):
|
|
|
|
|
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}
This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily
Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup
### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like
```python
## PRE-PR UX
def f(mode):
with mode.restore(): # user needs to understand this restore thing?
...
with Mode() as m:
pass
f(m)
```
Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
with mode:
...
f(Mode())
```
** Technical Details **
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
|
|
|
with BMode():
|
|
|
|
|
func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
class AMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
raise ErrorA
|
|
|
|
|
|
|
|
|
|
class BMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
z = torch.empty([])
assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
|
|
|
raise ErrorB
|
|
|
|
|
|
|
|
|
|
a = A(torch.empty(1))
|
|
|
|
|
b = B(torch.empty(1))
|
|
|
|
|
with self.assertRaises(ErrorA):
|
|
|
|
|
a + a
|
|
|
|
|
with self.assertRaises(ErrorB):
|
2022-05-03 17:13:03 +00:00
|
|
|
a + b
|
|
|
|
|
|
|
|
|
|
# B has precedence over A due to the subclass relationship yet
|
|
|
|
|
# modes take precedence over arguments
|
|
|
|
|
with self.assertRaises(ErrorA):
|
[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}
This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily
Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup
### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like
```python
## PRE-PR UX
def f(mode):
with mode.restore(): # user needs to understand this restore thing?
...
with Mode() as m:
pass
f(m)
```
Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
with mode:
...
f(Mode())
```
** Technical Details **
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
|
|
|
with AMode():
|
[Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360
This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
z = torch.empty([])
assert isinstance(z, LoggingTensor)
```
There are quite a few changes that were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.
Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
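The enter/exit behavior can be pictured with a small thread-local sketch (the names here are hypothetical, and the real logic lives in C++ across PythonMode and PythonModeTLS): entering stores the dispatch type in TLS and conceptually turns DispatchKey::Python on; exiting clears both.
```python
import threading

_tls = threading.local()  # stands in for PythonModeTLS

def python_mode_enter(dispatch_type):
    # Mirrors the pre-#84774 behavior: the mode is not re-entrant.
    if getattr(_tls, "mode_type", None) is not None:
        raise RuntimeError("python mode has already been set")
    _tls.mode_type = dispatch_type  # metadata the fallback kernel would consult

def python_mode_exit():
    assert getattr(_tls, "mode_type", None) is not None, "python mode was not set"
    _tls.mode_type = None
```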
There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.
Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.
Test Plan: - new tests
Reviewed By: ezyang
Differential Revision: D30698082
Pulled By: zou3519
fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 16:00:34 +00:00
|
|
|
b + b
|
|
|
|
|
with self.assertRaises(ErrorB):
|
|
|
|
with BMode():
|
|
|
|
a + a
|
|
|
|
|
with self.assertRaises(ErrorB):
|
|
|
|
with BMode():
|
|
|
|
a + b
|
|
|
|
|
|
|
|
|
def test_mode_with_make_subclass(self):
|
|
|
|
|
class SubTensor(torch.Tensor):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def __new__(cls, elem):
|
|
|
|
|
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
|
|
|
|
|
|
|
|
|
class BasicMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
x = torch.randn(3)
|
|
|
|
|
with BasicMode():
|
|
|
|
|
y = SubTensor(x)
|
|
|
|
|
self.assertIsInstance(y, SubTensor)
|
|
|
|
|
|
|
|
|
|
def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
|
2022-05-24 01:23:24 +00:00
|
|
|
with capture_logs(is_mode=True) as logs1:
|
|
|
|
with LoggingTensorMode():
|
2022-05-24 01:23:24 +00:00
|
|
|
torch.ones([2, 3])
|
|
|
|
|
with no_dispatch():
|
|
|
|
|
torch.ones([2, 3])
|
|
|
|
|
with capture_logs(is_mode=True) as logs2:
|
|
|
|
with LoggingTensorMode():
|
2022-05-24 01:23:24 +00:00
|
|
|
torch.ones([2, 3])
|
|
|
|
|
self.assertEqual(logs1, logs2)
|
|
|
|
|
2022-08-16 03:03:12 +00:00
|
|
|
def test_shallow_copy_and_detach(self) -> None:
|
|
|
|
|
seen = set()
|
|
|
|
|
test_case = self
|
|
|
|
|
|
|
|
|
|
class TestMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
tree_map_only(torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs))
|
|
|
|
|
if kwargs is None:
|
|
|
|
|
kwargs = {}
|
|
|
|
|
r = func(*args, **kwargs)
|
|
|
|
|
tree_map_only(torch.Tensor, lambda t: seen.add(t), r)
|
|
|
|
|
return r
|
|
|
|
|
|
|
|
|
|
with TestMode():
|
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
|
|
|
loss = (x * x).sum()
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
2022-05-03 17:13:03 +00:00
|
|
|
def test_exception_handling(self):
|
|
|
|
|
class A(torch.Tensor):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def __new__(cls, elem):
|
|
|
|
|
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
|
|
|
|
|
|
|
|
class AMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
2022-05-03 17:13:03 +00:00
|
|
|
if func.__name__ == 'randn.default':
|
|
|
|
|
raise RuntimeError()
|
|
|
|
return A(torch.zeros(()))
|
2022-05-03 17:13:03 +00:00
|
|
|
|
|
|
|
with AMode():
|
2022-05-03 17:13:03 +00:00
|
|
|
try:
|
|
|
|
|
torch.randn(())
|
|
|
|
|
except RuntimeError:
|
|
|
|
|
pass
|
|
|
|
|
self.assertTrue(isinstance(torch.zeros(()), A))
|
|
|
|
|
|
2022-06-01 18:47:38 +00:00
|
|
|
def test_with_mode_created_separately(self):
|
|
|
|
|
class ErrorA(RuntimeError):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
class A(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
raise ErrorA()
|
|
|
|
|
|
|
|
|
|
x = A()
|
|
|
|
|
with self.assertRaises(ErrorA):
|
|
|
|
|
with x:
|
|
|
|
|
torch.empty([])
|
|
|
|
|
|
|
|
|
|
def test_with_nested_modes(self):
|
|
|
|
|
class ErrorA(RuntimeError):
|
|
|
|
|
def __init__(self, msg):
|
2023-05-11 23:57:25 +00:00
|
|
|
super().__init__(msg)
|
2022-06-01 18:47:38 +00:00
|
|
|
|
|
|
|
|
class A(TorchDispatchMode):
|
|
|
|
|
def __init__(self, msg):
|
|
|
|
|
self.msg = msg
|
|
|
|
|
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
raise ErrorA(self.msg)
|
|
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ErrorA, "layer2"):
|
|
|
|
|
with A("layer1"):
|
2023-01-12 16:44:29 +00:00
|
|
|
with A("layer2"):
|
2022-06-01 18:47:38 +00:00
|
|
|
torch.empty([])
|
|
|
|
|
|
2022-06-17 18:47:18 +00:00
|
|
|
def test_make_subclass_with_modes(self):
|
2022-06-01 18:47:38 +00:00
|
|
|
class ModeTensor(torch.Tensor):
|
|
|
|
|
def __new__(cls, elem, mode):
|
|
|
|
|
r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
|
|
|
|
r.elem = elem
|
|
|
|
|
r.mode = mode
|
|
|
|
|
return r
|
|
|
|
|
|
2022-06-17 18:47:18 +00:00
|
|
|
@classmethod
|
|
|
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
|
|
raise NotImplementedError("Shouldn't be here")
|
2022-06-01 18:47:38 +00:00
|
|
|
|
|
|
|
|
class Mode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
def unwrap(e):
|
|
|
|
|
if isinstance(e, ModeTensor):
|
|
|
|
|
return e.elem
|
|
|
|
|
else:
|
|
|
|
|
return e
|
|
|
|
|
|
|
|
|
|
def wrap(t):
|
|
|
|
|
if isinstance(t, torch.Tensor):
|
|
|
|
|
return ModeTensor(t, self)
|
|
|
|
|
else:
|
|
|
|
|
return t
|
|
|
|
|
|
|
|
|
|
return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))
|
|
|
|
|
|
2022-06-17 18:47:18 +00:00
|
|
|
class BasicMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
|
2022-06-01 18:47:38 +00:00
|
|
|
x = torch.tensor(4.)
|
|
|
|
|
with Mode():
|
|
|
|
|
y = x + x
|
|
|
|
|
z = y + y
|
|
|
|
|
self.assertIsInstance(y, ModeTensor)
|
|
|
|
|
self.assertIsInstance(z, ModeTensor)
|
|
|
|
|
|
2022-06-17 18:47:18 +00:00
|
|
|
with Mode():
|
|
|
|
|
with BasicMode(): # we can't nest two modes that call make_subclass because it only accepts vanilla tensors
|
|
|
|
|
y = x + x
|
|
|
|
|
z = y + y
|
|
|
|
|
self.assertIsInstance(y, ModeTensor)
|
|
|
|
|
self.assertIsInstance(z, ModeTensor)
|
|
|
|
|
|
|
|
|
|
assert self.assertRaisesRegex(RuntimeError, "subclass Mode but.* associated to a python object of type Mode")
|
|
|
|
|
|
2022-07-06 20:18:50 +00:00
|
|
|
def test_notimplemented_mode(self):
|
|
|
|
|
sub_count = 0
|
|
|
|
|
|
|
|
|
|
class PoliteMode(TorchDispatchMode):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.pre_count = 0
|
|
|
|
|
self.post_count = 0
|
|
|
|
|
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
self.pre_count += 1
|
|
|
|
|
if any(t is not torch.Tensor for t in types):
|
|
|
|
|
return NotImplemented
|
|
|
|
|
self.post_count += 1
|
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
class SubTensor(torch.Tensor):
|
|
|
|
|
def __new__(cls, elem):
|
|
|
|
|
r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
|
|
|
|
|
r.elem = elem
|
|
|
|
|
return r
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
|
|
|
nonlocal sub_count
|
|
|
|
|
sub_count += 1
|
|
|
|
|
|
|
|
|
|
def unwrap(t):
|
|
|
|
|
if isinstance(t, SubTensor):
|
|
|
|
|
return t.elem
|
|
|
|
|
else:
|
|
|
|
|
return t
|
|
|
|
|
|
|
|
|
|
return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
|
|
|
|
|
|
|
|
|
__torch_function__ = torch._C._disabled_torch_function_impl
|
|
|
|
|
|
|
|
|
|
a = SubTensor(torch.randn(2))
|
2022-07-12 20:14:48 +00:00
|
|
|
with PoliteMode() as mode:
|
2022-07-06 20:18:50 +00:00
|
|
|
a.abs()
|
|
|
|
|
|
|
|
|
|
self.assertEqual(mode.pre_count, 2)
|
|
|
|
|
self.assertEqual(mode.post_count, 1)
|
|
|
|
|
self.assertEqual(sub_count, 1)
|
|
|
|
|
|
|
|
|
|
# make sure this doesn't error
|
|
|
|
|
with PoliteMode():
|
|
|
|
|
with PoliteMode():
|
|
|
|
|
a.abs()
|
|
|
|
|
|
2023-01-12 16:44:29 +00:00
|
|
|
def test_nesting_same_mode(self):
|
|
|
|
|
# If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.
|
2022-06-17 18:47:18 +00:00
|
|
|
|
2023-01-12 16:44:29 +00:00
|
|
|
with capture_logs(is_mode=True) as logs:
|
|
|
|
|
with LoggingTensorMode() as reenabled:
|
|
|
|
|
with reenabled:
|
|
|
|
|
torch.empty([])
|
|
|
|
|
self.assertExpectedInline('\n'.join(logs), """\
|
2023-06-21 16:12:52 +00:00
|
|
|
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
|
|
|
|
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
2022-06-01 18:47:38 +00:00
|
|
|
|
2022-05-03 17:13:03 +00:00
|
|
|
|
2022-06-03 21:26:18 +00:00
|
|
|
def test_error_using_class_method_on_mode(self):
|
|
|
|
|
class A(TorchDispatchMode):
|
|
|
|
|
@classmethod
|
|
|
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
|
|
|
return func(args, kwargs)
|
|
|
|
|
|
|
|
|
|
x = torch.tensor(5.)
|
2022-10-19 14:36:40 +00:00
|
|
|
with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
|
2022-06-03 21:26:18 +00:00
|
|
|
with A():
|
|
|
|
|
x + x
|
|
|
|
|
|
|
|
|
def test_get_cur_mode(self):
|
|
|
|
|
class A(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
2022-06-10 14:23:25 +00:00
|
|
|
pass
|
|
|
|
|
|
|
|
|
self.assertEqual(_get_current_dispatch_mode(), None)
|
2022-06-10 14:23:26 +00:00
|
|
|
|
|
|
|
with A() as mode1:
|
|
|
|
|
self.assertEqual(_get_current_dispatch_mode(), mode1)
|
2022-06-10 14:23:26 +00:00
|
|
|
|
|
|
|
with mode1:
|
|
|
|
|
with A() as mode2:
|
|
|
|
|
self.assertEqual(_get_current_dispatch_mode(), mode2)
|
2022-06-10 14:23:26 +00:00
|
|
|
|
|
|
|
def test_get_mode_stack(self):
|
|
|
|
|
class A(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
2022-06-10 14:23:26 +00:00
|
|
|
pass
|
|
|
|
|
|
|
|
|
self.assertEqual(_get_current_dispatch_mode_stack(), [])
|
2022-06-10 14:23:26 +00:00
|
|
|
|
|
|
|
with A() as mode1:
|
|
|
|
|
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])
|
2022-06-10 14:23:26 +00:00
|
|
|
|
|
|
|
with mode1:
|
|
|
|
|
with A() as mode2:
|
|
|
|
|
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])
|
2022-06-10 14:23:26 +00:00
|
|
|
|
|
|
|
|
def test_all_same_mode(self):
|
|
|
|
|
x = LoggingTensorMode()
|
|
|
|
|
y = LoggingTensorMode()
|
|
|
|
|
self.assertTrue(all_same_mode([x, x, x]))
|
|
|
|
|
self.assertFalse(all_same_mode([x, None]))
|
|
|
|
|
self.assertFalse(all_same_mode([x, y]))
|
|
|
|
|
|
2022-05-02 20:06:43 +00:00
|
|
|
def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
|
2021-10-13 20:49:31 +00:00
|
|
|
x = LoggingTensor(torch.tensor([2.0, 3.0]))
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
|
|
|
|
|
x.tolist()
|
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
|
|
|
|
|
x.numpy()
|
|
|
|
|
with self.assertRaises(AssertionError):
|
|
|
|
|
self.assertEqual(x, None)
|
|
|
|
|
|
2023-04-21 07:17:19 +00:00
|
|
|
def test_record_stream(self) -> None:
|
|
|
|
|
class TestMode(TorchDispatchMode):
|
|
|
|
|
def __init__(self, testcase):
|
|
|
|
|
self.testcase = testcase
|
|
|
|
|
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
self.testcase.assertEqual(func.name(), "aten::record_stream")
|
|
|
|
|
self.testcase.assertIsInstance(args[0], torch.Tensor)
|
|
|
|
|
self.testcase.assertIsInstance(args[1], torch.Stream)
|
|
|
|
|
self.testcase.assertEqual(args[1].stream_id, 1)
|
|
|
|
|
self.testcase.assertEqual(args[1].device_index, 2)
|
|
|
|
|
self.testcase.assertEqual(args[1].device_type, 3)
|
|
|
|
|
|
|
|
|
|
t = torch.tensor(5.)
|
|
|
|
|
s = torch.Stream(stream_id=1, device_index=2, device_type=3)
|
|
|
|
|
with TestMode(self):
|
|
|
|
|
t.record_stream(s)
|
|
|
|
|
|
|
|
|
|
def test_return_stream(self) -> None:
|
2024-02-12 23:30:08 +00:00
|
|
|
with _scoped_library("test_return_stream", "DEF") as l_def:
|
|
|
|
|
l_def.define("return_stream(Tensor self) -> Stream")
|
|
|
|
|
with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
|
|
|
|
|
l_impl.impl("return_stream",
|
|
|
|
|
lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2))
|
|
|
|
|
|
|
|
|
|
class TestMode(TorchDispatchMode):
|
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
|
|
|
return torch.Stream(stream_id=1, device_index=2, device_type=3)
|
|
|
|
|
|
|
|
|
|
t = torch.tensor(5.)
|
|
|
|
|
s = torch.ops.test_return_stream.return_stream(t)
|
|
|
|
|
self.assertIsInstance(s, torch.Stream)
|
|
|
|
|
self.assertEqual(s.stream_id, 0)
|
|
|
|
|
self.assertEqual(s.device_index, 1)
|
|
|
|
|
self.assertEqual(s.device_type, 2)
|
|
|
|
|
|
|
|
|
|
with TestMode():
|
|
|
|
|
s = torch.ops.test_return_stream.return_stream(t)
|
|
|
|
|
self.assertIsInstance(s, torch.Stream)
|
|
|
|
|
self.assertEqual(s.stream_id, 1)
|
|
|
|
|
self.assertEqual(s.device_index, 2)
|
|
|
|
|
self.assertEqual(s.device_type, 3)
|
2023-04-21 07:17:19 +00:00
|
|
|
|
|
|
|
    def test_subclass_autograd_device_check(self) -> None:
        class NonWrapperSubclass(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ['elem']

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # Wrong device here!
                r = torch.Tensor._make_subclass(cls, elem.to("meta"), elem.requires_grad)
                # ...the real tensor is held as an element on the tensor.
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSubclass) else e

                def wrap(e):
                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                return rs

        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
        y = torch.randn(2, requires_grad=True)
        z = x * y
        self.assertIsInstance(z, NonWrapperSubclass)
        z.sum().backward(torch.tensor(1))
        self.assertEqual(x.grad, y)
        self.assertEqual(y.grad, x)

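    # The pattern above (unwrap inputs, run the real op, rewrap outputs via
    # tree_map) is the standard recipe for non-wrapper subclasses. The "Wrong
    # device here!" trick makes the outer tensor report device 'meta' while
    # the payload lives on CPU, which is exactly the mismatch the autograd
    # device check under test has to tolerate.
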
    def test_none_wrapping(self):
        # A Tensor subclass that returns None when doing add
        # See LoggingTensor above for more details on the subclass
        class SubclassWithNone(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls, elem.size(),
                    dtype=elem.dtype, layout=elem.layout,
                    device=elem.device, requires_grad=elem.requires_grad
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, SubclassWithNone) else e

                def wrap(e):
                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                if func.overloadpacket.__name__ == "add":
                    return None
                else:
                    return rs

        x = SubclassWithNone(torch.rand(2))
        # Make sure both run without error
        self.assertIsInstance(x * 2, SubclassWithNone)
        self.assertIsNone(x + 2)

        x.requires_grad_()
        out = x.acos().sum()

        # The backward of acos does add then rsqrt so here we make sure that the
        # undefined Tensor generated by the user code is nicely handled.
        # If acos formula changes in the future, this can be replaced by any other
        # function that does add then something in the backward in a composite way
        with self.assertRaisesRegex(RuntimeError, "but got None"):
            out.backward()

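    # Why "add then rsqrt": d/dx acos(x) = -1 / sqrt(1 - x^2), so the
    # composite backward computes something like (-x * x + 1) -- hitting
    # aten::add, where this subclass returns None -- and then takes rsqrt of
    # it, which is where the undefined gradient has to be handled.
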
    def test_storage_can_be_converted_to_python_object(self):
        s = torch.Storage()
        z = LoggingTensor(torch.empty([]))
        z.set_(s)

    def test_autograd_in_attr(self):
        # We want the wrapped Tensor to require gradients!
        true_t = torch.rand(2, requires_grad=True)
        t = LoggingTensorReentrant(true_t)

        out = t + 2

        self.assertFalse(out.requires_grad)
        self.assertIsNone(out.grad_fn)

        self.assertTrue(out.elem.requires_grad)
        self.assertIsNotNone(out.elem.grad_fn)

        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            out.sum().backward()

        out.elem.sum().backward()

        self.assertIsNone(t.grad)
        self.assertIsNotNone(t.elem.grad)

    def test_dispatch_super_call(self):
        called = []

        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
                return super().__torch_dispatch__(func, types, args, kwargs)

        x = torch.randn(2)
        y = torch.randn(2)
        self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
        self.assertEqual(called, [torch.ops.aten.add.Tensor])

    def test_dispatch_super_call_list_arg(self):
        called = []

        class SubTensorWithListArg(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
                return super().__torch_dispatch__(func, types, list(args), kwargs)

        x = torch.randn(2)
        self.assertEqual(SubTensorWithListArg(x).neg(), x.neg())
        self.assertEqual(called, [torch.ops.aten.neg.default])

    def test_dispatch_super_dont_autograd(self):
        called = []

        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
                # This argument still requires grad because it was passed
                # through directly...
                self.assertTrue(args[0].requires_grad)
                r = super().__torch_dispatch__(func, types, args, kwargs)
                # But the output better not require grad, because that means
                # you did autograd again in torch dispatch (oops)
                self.assertFalse(r.requires_grad)
                return r

        x = SubTensor(torch.randn(2, requires_grad=True))
        x.neg()
        self.assertEqual(called, [torch.ops.aten.neg.default])

    def test_set_data(self):
        called = 0

        class SubTensor(torch.Tensor):
            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal called
                called += 1
                return super().__torch_dispatch__(func, types, args, kwargs)

        x = SubTensor(torch.empty(2))
        x.data
        self.assertEqual(called, 1)
        x.data = torch.empty(2)
        self.assertEqual(called, 1)
        x.data
        self.assertEqual(called, 2)
        self.assertIs(type(x), SubTensor)
        x.set_(torch.empty(2))
        self.assertEqual(called, 3)
        x.data
        self.assertEqual(called, 4)
        self.assertIs(type(x), SubTensor)

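    # What the counts above pin down: each read of .data dispatches exactly
    # once (it behaves like a detach), and set_() dispatches as well, but the
    # `x.data = ...` assignment is not an aten op, so it leaves the counter
    # untouched while still preserving the subclass type.
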
    def test_construct_int_tensor(self):
        class SubTensor(torch.Tensor):
            pass
        # should not fail
        SubTensor(torch.zeros(2, dtype=torch.int))

    def test_multiple_ops_subclass(self):
        # This is a Direct Subclass, don't do that!
        class MySubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                r = torch.Tensor._make_subclass(cls, elem)
                return r

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with no_dispatch():
                    return func(*args, **kwargs)

        x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
        y = x.conj()
        # Details of the bug that this tests for:
        # Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
        # There are a few calls to the dispatcher that are going to happen here:
        # - call_exp: User calling exp on y
        #   - PythonTLSSnapshot: records the TLS on entry and redispatch
        #   - AutogradCPU: no input requires grad, so does nothing and redispatch
        #   - Conjugate: no special implementation for exp: use the fallback that
        #     first clone the Tensor (to materialize the conj) then redispatch
        #     - call_clone: conjugate fallback calling clone on y
        #       - PythonTLSSnapshot: records the TLS on entry and redispatch
        #       - (AutogradCPU: skipped as autograd added itself to the exclude set above)
        #       - Conjugate: special implementation for clone: just skip this key
        #       - Python: Reset the TLS based on the snapshot above and call the user implementation (this
        #         actually calls into the dispatcher again but since we disable both our keys
        #         before, not detailed here)
        #       - exit Python: restore the TLS and exit
        #       - exit Conjugate: nothing was inplace so just exit
        #       - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
        #   - Python: Reset the TLS again based on the snapshot. <- this used to fail
        #   - More steps....
        y.exp()

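    # The key set claimed in the walkthrough above can be inspected directly;
    # this is a sketch using the private (and therefore unstable) helper
    # torch._C._dispatch_keys:
    #
    #   keys = torch._C._dispatch_keys(y)
    #   # repr(keys) lists Conjugate, Python, AutogradCPU, CPU, ...
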
    @staticmethod
    def subclass_helper(cls, data, use_wrapper_subclass, **kwargs):
        if use_wrapper_subclass:
            kwargs["device"] = data.device
            kwargs["dtype"] = data.dtype
            kwargs["layout"] = data.layout
            kwargs["requires_grad"] = True
            return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]
        else:
            return torch.Tensor._make_subclass(cls, data, True, **kwargs)

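    # Usage sketch for the helper above (MyTensor is hypothetical):
    #
    #   class MyTensor(torch.Tensor):
    #       @staticmethod
    #       def __new__(cls, data, wrapper):
    #           return TestPythonDispatch.subclass_helper(
    #               cls, data, wrapper, dispatch_device=True)
    #
    # With use_wrapper_subclass=True the result is a shell that owns no
    # storage, and any metadata query opted into via the dispatch_* kwargs is
    # answered by __torch_dispatch__; with False it shares data's storage.
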
    def test_is_contiguous_slow_path(self):
        data = torch.randn(3, 3)
        contiguous_data = data.clone()
        not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))

        for use_wrapper_subclass in [True, False]:
            class ExampleTensor1(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class ExampleTensor2(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.is_contiguous:
                        return contiguous_data.is_contiguous()
                    return NotImplemented

            class ExampleTensor3(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.is_contiguous:
                        return not_contiguous_data.is_contiguous()
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
            e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.is_contiguous()
            with self.assertRaisesRegex(TypeError, err_msg):
                e.contiguous()

            e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.is_contiguous(), True)
            e.contiguous()  # this will just return the original TensorImpl since is_contiguous = True

            err_msg = "Multiple dispatch failed for"
            e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.is_contiguous(), False)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.contiguous()

    def test_fancy_strides(self):
        calls = []

        class ExampleTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data):
                return TestPythonDispatch.subclass_helper(cls, data, False, dispatch_sizes_strides_policy="strides")

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func in [
                    torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
                    torch.ops.aten.is_non_overlapping_and_dense.default,
                    torch.ops.aten.stride.default
                ]:
                    calls.append((func, list(args)[1:]))
                    return None
                with no_dispatch():
                    return func(*args, **kwargs)

        e = ExampleTensor(torch.randn(2, 2))
        self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
        self.assertEqual(calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])])
        calls.clear()
        self.assertFalse(torch.ops.aten.is_strides_like_format.default(e, torch.channels_last))
        self.assertEqual(calls, [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])])
        calls.clear()
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
        self.assertEqual(calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])])

    def test_device_slowpath(self):
        for use_wrapper_subclass in [True]:
            class ExampleTensor1(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_device=True)

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class ExampleTensor2(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_device=True)

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.device:
                        return torch.device('meta')
                    return NotImplemented

            class ExampleTensor3(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_device=True)

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.device:
                        return torch.device('meta')
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'"
            with self.assertRaisesRegex(TypeError, err_msg):
                e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
                e.device()

            ten = torch.rand([1])
            e = ExampleTensor2(torch.randn(3, 3, device='cpu'), use_wrapper_subclass)
            self.assertEqual(e.device.type, 'meta')
            self.assertEqual(ten.type_as(e).device.type, 'meta')

            e = ExampleTensor3(torch.randn(3, 3, device='cpu'), use_wrapper_subclass)
            self.assertEqual(e.device.type, 'meta')
            self.assertEqual(ten.type_as(e).device.type, 'meta')

    def test_dim_slowpath(self):
        data = torch.randn(3, 3)

        for use_wrapper_subclass in [True, False]:
            class DimNotImplementedTensor(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class DimImplementedTensor(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'"
            e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.dim()

            t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(t.dim(), 2)

    def test_maybe_tuple_bug(self):
        class T(torch.Tensor):
            @classmethod
            def __torch_function__(cls, *args, **kwargs):
                pass

        a = torch.rand(3)

        a[[T(), T()]]

    def test_standard_is_not_subclass(self):
        # https://github.com/pytorch/pytorch/issues/79079
        self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))

    def test_sym_sizes_strides_slow_path(self):
        class TestTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls, (0,), dispatch_sizes_strides_policy="sizes")
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func in (
                    torch.ops.aten.sym_size.default,
                    torch.ops.aten.sym_stride.default
                ):
                    from torch._dynamo.source import ConstantSource
                    from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic
                    shape_env = ShapeEnv()
                    si = shape_env.create_symintnode(
                        shape_env.create_symbol(
                            123,
                            source=ConstantSource("abc"),
                            dynamic_dim=DimDynamic.DUCK,
                            constraint_dim=None,
                        ),
                        hint=123
                    )
                    return (si,)

        t = TestTensor()
        si = t.size()[0]
        self.assertIsInstance(si, torch.SymInt)
        si = t.stride()[0]
        self.assertIsInstance(si, torch.SymInt)

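    # A condensed sketch of the SymInt plumbing exercised above (the same
    # private APIs from torch.fx.experimental.symbolic_shapes, so subject to
    # change):
    #
    #   shape_env = ShapeEnv()
    #   sym = shape_env.create_symbol(123, source=ConstantSource("abc"),
    #                                 dynamic_dim=DimDynamic.DUCK,
    #                                 constraint_dim=None)
    #   si = shape_env.create_symintnode(sym, hint=123)
    #   assert isinstance(si, torch.SymInt)
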
    def test_strides_slow_path(self):
        for use_wrapper_subclass in [True, False]:
            class StridesNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class StridesCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func == torch.ops.aten.sym_stride.default:
                        return (4, 2)
                    return NotImplemented

            class StridesDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func == torch.ops.aten.sym_stride.default:
                        return None
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'"
            e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.stride()

            e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.stride(), (4, 2))

            e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
            self.assertEqual(e.stride(), (2, 1))

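    # The three classes above cover the whole contract: NotImplemented from
    # every handler makes multiple dispatch fail, a returned tuple is taken
    # verbatim, and returning None for sym_stride asks the slow path to derive
    # the default contiguous strides from the sizes, hence (2, 1) for a 6x2
    # tensor.
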
    def test_sizes_slow_path(self):
        for use_wrapper_subclass in [True, False]:
            data = torch.randn(6, 2)

            class SizesNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    return NotImplemented

            class SizesCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    if func.overloadpacket == torch.ops.aten.sym_size:
                        return (5, 3)
                    return NotImplemented

            class SizesDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    if func.overloadpacket == torch.ops.aten.sym_size:
                        return None
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'"
            e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.size()

            e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.size(), (5, 3))

            e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
            self.assertEqual(e.size(), (4, 2))

    def test_custom_size_policy_dynamic_shapes(self):
        data = torch.randn(6, 2)

        class CustomSizeDynamicShapesTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                return torch.Tensor._make_wrapper_subclass(
                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                    # Calling the overload that has kwargs causes us to go down the first overload path,
                    # which will **always** specialize sizes.
                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
                    cls,
                    inner.size(),
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                    "sizes",
                )

            def __init__(self, inner):
                self.inner = inner

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func == torch.ops.aten.sym_size.default:
                    return args[0].inner.shape
                if func == torch.ops.aten.sym_stride.default:
                    return args[0].inner.shape
                return NotImplemented

        x = torch.ones(2, 2)

        def trace_fn(x):
            x_wrapper = CustomSizeDynamicShapesTensor(x)
            return x_wrapper.size(), x_wrapper.stride()

        fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
    return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""")

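    # The assertExpectedInline output above is the interesting part: under
    # make_fx(..., tracing_mode="symbolic"), the wrapper's size()/stride()
    # answers show up as torch.ops.aten.sym_size.int calls on the traced
    # input rather than baked-in constants, i.e. the custom "sizes" policy
    # composes with dynamic shapes.
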
    def test_data_ptr_respects_numel_slow_path(self):
        data = torch.randn(6, 2)

        class NumelDefaultReturn(torch.Tensor):
            @staticmethod
            def __new__(cls, data, wrapper):
                return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.dim:
                    return data.dim()
                if func.overloadpacket == torch.ops.aten.numel:
                    numel_called[0] = True
                    return None
                return NotImplemented

        for use_wrapper_subclass in (False, True):
            numel_called = [False]
            e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass)
            e.data_ptr()
            self.assertTrue(numel_called[0])

    def test_layout_slow_path(self):
        for use_wrapper_subclass in [True, False]:
            data = torch.randn(6, 2)

            class LayoutNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_layout=True)

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class LayoutCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_layout=True)

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.layout:
                        return torch.sparse_csr
                    return NotImplemented

            class LayoutDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_layout=True)

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.layout:
                        return data.layout
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'"
            e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.layout

            e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.layout, torch.sparse_csr)

            e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
            self.assertEqual(e.layout, torch.strided)


class TestPythonDispatcher(TestCase):
    def test_basic(self):
        x = torch.randn(2, requires_grad=True)
        r = torch._C._EnablePythonDispatcher()
        torch.add(x, x)

    def test_lstsq(self):
        a = torch.randn(4, 3)
        b = torch.rand(4, 3)
        expected_shape = torch.linalg.lstsq(a, b).solution.shape
        r = torch._C._EnablePythonDispatcher()
        python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
        self.assertEqual(expected_shape, python_disp_shape)

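    # `r` is not dead in the two tests above: _EnablePythonDispatcher()
    # returns a guard object, and the Python dispatcher stays enabled only
    # while that guard is alive, so binding it keeps the feature on for the
    # calls that follow.
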
class TestWrapperSubclassAliasing(TestCase):
|
|
|
|
|
|
|
|
|
|
def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
|
|
|
|
|
def to_subclass(t: torch.Tensor):
|
|
|
|
|
return TwoTensor(t, t.clone())
|
|
|
|
|
|
|
|
|
|
result_ref = op(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
|
|
|
|
|
kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)
|
|
|
|
|
|
|
|
|
|
result_test = op(*args_subclass, **kwargs_subclass)
|
|
|
|
|
|
2023-10-30 18:25:51 +00:00
|
|
|
args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs)
|
add return_and_correct_aliasing() util for wrapper subclasses (#107915)
This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.
Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.
One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):
(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)
(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation
(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)
(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though?
Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
2023-08-29 02:43:08 +00:00
|
|
|
args_ref_flat_tensors = [x for x in args_ref_flat if isinstance(x, torch.Tensor)]
|
|
|
|
|
|
2023-10-30 00:05:29 +00:00
|
|
|
args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass))
|
add return_and_correct_aliasing() util for wrapper subclasses (#107915)
This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.
Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.
One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):
(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)
(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation
(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)
(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though?
Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
2023-08-29 02:43:08 +00:00
|
|
|
args_test_flat_tensors = [x for x in args_test_flat if isinstance(x, torch.Tensor)]
|
|
|
|
|
|
2023-10-30 00:05:29 +00:00
|
|
|
result_ref_flat = pytree.tree_leaves(result_ref)
|
add return_and_correct_aliasing() util for wrapper subclasses (#107915)
This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.
Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.
One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):
(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)
(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation
(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)
(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though?
Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
2023-08-29 02:43:08 +00:00
|
|
|
result_ref_flat_tensors = [x for x in result_ref_flat if isinstance(x, torch.Tensor)]
|
|
|
|
|
|
2023-10-30 00:05:29 +00:00
|
|
|
result_test_flat = pytree.tree_leaves(result_test)
result_test_flat_tensors = [x for x in result_test_flat if isinstance(x, torch.Tensor)]
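# (Aside: `pytree.tree_leaves` flattens arbitrarily nested containers, e.g.
#  tree_leaves({'a': t1, 'b': [3, t2]}) yields [t1, 3, t2], which is why the
#  isinstance filters above are needed to narrow the leaves down to tensors.)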
for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
    for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
        # If an output *is* an input in the reference run (same Python object),
        # the same identity must hold in the subclass run.
        out_is_inpt = o_ref is a_ref
        if out_is_inpt:
            self.assertTrue(o_test is a_test)

        # If an output aliases an input's storage in the reference run, the
        # subclass run must alias too; otherwise it must not alias.
        out_aliases_inpt = StorageWeakRef(o_ref.untyped_storage()) == StorageWeakRef(a_ref.untyped_storage())
        if out_aliases_inpt:
            self.assertTrue(StorageWeakRef(o_test.untyped_storage()) == StorageWeakRef(a_test.untyped_storage()))
        else:
            self.assertFalse(StorageWeakRef(o_test.untyped_storage()) == StorageWeakRef(a_test.untyped_storage()))
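The same storage-identity predicate can be checked on plain tensors, for intuition. A standalone sketch (not part of the test harness):

import torch
from torch.multiprocessing.reductions import StorageWeakRef

base = torch.ones(4)
view = base.view(2, 2)   # views share the base tensor's storage
clone = base.clone()     # clones get fresh storage

assert StorageWeakRef(view.untyped_storage()) == StorageWeakRef(base.untyped_storage())
assert StorageWeakRef(clone.untyped_storage()) != StorageWeakRef(base.untyped_storage())
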
# This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
# a util for wrapper subclasses to promise correct aliasing behavior.
# It's probably overkill to test every OpInfo,
# so I picked a sampling of ops with representative schemas.
@ops([op for op in op_db if op.name in [
    'mul',  # out-of-place
    'cat',  # out-of-place (TensorList input)
    'index',  # out-of-place (Optional TensorList input)
    'mul_',  # inplace
    'view',  # view
    't_',  # inplace-view
    'split',  # view (multi-return)
    'native_batch_norm',  # mutable op (returns outputs and mutates some inputs)
]], allowed_dtypes=(torch.float,))
def test_wrapper_subclass_aliasing(self, device, dtype, op):
    samples = op.sample_inputs(device, dtype)
    sample = first_sample(self, samples)
    args = (sample.input, *sample.args)
    kwargs = sample.kwargs
    self._test_wrapper_subclass_aliasing(op, args, kwargs)
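For reference, here is how the sampled schema categories behave on plain tensors, as a standalone sketch (not part of the test):

import torch

x = torch.ones(2, 2)
y = torch.ones(2, 2)
torch.mul(x, y)                 # out-of-place: fresh output storage
x.mul_(y)                       # in-place: returns x itself
v = x.view(4)                   # view: output aliases x's storage
x.t_()                          # inplace-view: mutates x's metadata, returns x
a, b = torch.ones(4).split(2)   # multi-return view: both outputs alias the base
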
@ops(custom_op_db, allowed_dtypes=(torch.float,))
def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
    samples = op.sample_inputs(device, dtype)
    sample = first_sample(self, samples)
    args = (sample.input, *sample.args)
    kwargs = sample.kwargs
    self._test_wrapper_subclass_aliasing(op, args, kwargs)
def test_wrapper_subclass_aliasing_conv2d(self, device):
    args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
    kwargs = {}
    # conv2d has a default arg 'int[2] strides=0',
    # which torchscript expands into 'int[2] strides=[0, 0]'.
    # Make sure that _return_and_correct_aliasing can handle this case.
    # (inference_mode is used so conv2d doesn't decompose and actually hits torch_dispatch)
    with torch.inference_mode():
        self._test_wrapper_subclass_aliasing(torch.ops.aten.conv2d.default, args, kwargs)
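The defaulted-argument path being exercised can also be hit directly by calling the overload with its trailing arguments omitted. A standalone sketch (shapes chosen arbitrarily):

import torch

x = torch.randn(2, 4, 8, 8)
w = torch.randn(4, 4, 3, 3)
# bias, stride, padding, dilation, and groups are all left at their schema
# defaults here -- the case the test above guards against.
out = torch.ops.aten.conv2d.default(x, w)
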
def test_wrapper_subclass_aliasing_out_op(self, device):
    # Make sure that _return_and_correct_aliasing can handle kwargs with mutable tensors
    args = (torch.ones(4), torch.ones(4))
    kwargs = {'out': torch.empty(4)}
    self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)
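For comparison, the aliasing contract of out= variants on plain tensors, as a standalone sketch:

import torch

buf = torch.empty(4)
res = torch.add(torch.ones(4), torch.ones(4), out=buf)
assert res is buf  # an out= op returns the `out` tensor itself, not a fresh one
assert torch.equal(res, torch.full((4,), 2.0))
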
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
if __name__ == '__main__':
    run_tests()