pytorch/c10/core/impl/TorchDispatchModeTLS.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

197 lines
6.7 KiB
C++
Raw Normal View History

#include <c10/core/DispatchKey.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
#include <c10/util/irange.h>
#include <utility>
namespace c10::impl {
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
thread_local TorchDispatchModeTLS torchDispatchModeState;
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
if (!torchDispatchModeState.stack_.empty())
return true;
if (!skip_infra_modes) {
for (const auto i : c10::irange(
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
return true;
}
}
}
return false;
}
void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
std::shared_ptr<PyObject_TorchDispatchMode> mode) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
if (!any_modes_set()) {
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
torchDispatchModeState.stack_.push_back(std::move(mode));
}
const std::shared_ptr<PyObject_TorchDispatchMode> TorchDispatchModeTLS::
pop_stack() {
std::shared_ptr<PyObject_TorchDispatchMode> out;
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
if (!torchDispatchModeState.stack_.empty()) {
out = torchDispatchModeState.stack_.back();
torchDispatchModeState.stack_.pop_back();
} else {
for (int64_t i =
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
i >= 0;
--i) {
if (torchDispatchModeState.infra_modes_[i].has_value()) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
out = std::move(torchDispatchModeState.infra_modes_[i].value());
torchDispatchModeState.infra_modes_[i] = std::nullopt;
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
break;
}
}
}
TORCH_CHECK(out, "trying to pop from empty mode stack");
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
return out;
}
const std::
tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
TorchDispatchModeTLS::pop_highest_infra_mode() {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
i >= 0;
--i) {
if (torchDispatchModeState.infra_modes_[i].has_value()) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
auto out_mode = torchDispatchModeState.infra_modes_[i].value();
torchDispatchModeState.infra_modes_[i] = std::nullopt;
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
return std::make_tuple(
std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
}
}
TORCH_CHECK(
false, "Called pop_highest_infra_mode, but no infra modes were active.")
}
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
const std::shared_ptr<PyObject_TorchDispatchMode>& TorchDispatchModeTLS::
get_stack_at(int64_t idx) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
// Our "logical" stack includes both:
// - any user modes (the entire torchDispatchModeState.stack_)
// - any infra modes (members of torchDispatchModeState.infra_modes_ that are
// not None)
// idx == 0 means the "bottom" of the stack, which starts with any infra
// modes (iterating from lowest-priority to highest-priority).
auto curr_idx = idx;
for (const auto i :
c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
if (torchDispatchModeState.infra_modes_[i].has_value()) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
if (curr_idx == 0) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
return torchDispatchModeState.infra_modes_[i].value();
}
curr_idx -= 1;
}
}
// At this point, we're guaranteed that curr_idx < stack_.size()
return torchDispatchModeState.stack_[curr_idx];
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
}
int64_t TorchDispatchModeTLS::stack_len() {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
int64_t infra_modes_len = 0;
for (const auto i :
c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
infra_modes_len += 1;
}
}
return stack_len + infra_modes_len;
}
const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
TorchDispatchModeTLS::get_mode(TorchDispatchModeKey mode_key) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
}
void TorchDispatchModeTLS::set_mode(
const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
TorchDispatchModeKey mode_key) {
TORCH_CHECK(
torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
std::nullopt,
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
"trying to set the current ",
to_string(mode_key),
", but one already exists");
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
}
const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) {
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
std::nullopt;
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
if (out.has_value() && !any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
return out;
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
}
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
return torchDispatchModeState;
}
void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
torchDispatchModeState = std::move(state);
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
} else {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
[Modes] remove enable and rewrite mode stack (squashed) (#84774) Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function} This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup ### Background Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like ```python ## PRE-PR UX def f(mode): with mode.restore(): # user needs to understand this restore thing? ... with Mode() as m: pass f(m) ``` Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write ```python ## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR def f(mode): with mode: ... f(Mode()) ``` ** Technical Details ** With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774 Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-26 20:42:07 +00:00
}
// UTIL
bool dispatch_mode_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
reorder proxy / fake modes so they always run last (#104482) **Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates: (1) TLS changes in `TorchDispatchModeTLS.h/cpp`. I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack). `TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes. `TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes". `TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes). There were also some mild codegen changes to support the new enum (2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py` The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`. I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly. (2) dispatching logic in `python_arg_parser.cpp` This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now: ``` (a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last) (b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses) (c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses) ``` Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass. How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key. (3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends). ----- original description below ----- The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that. **(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode** This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes. I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler. **(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()** This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs. **(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's** Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active. **(4) Allow traceable tensor subclasses to access storages from python** This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable" **(5) Fixed subclass fakeification** @wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct. **(6) make tensor subclasses resizeable** Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`. **(7) Added a basic DoubleTensor subclass for testing** I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482 Approved by: https://github.com/ezyang ghstack dependencies: #107415
2023-08-28 20:29:13 +00:00
TorchDispatchModeTLS::any_modes_set();
}
std::string to_string(TorchDispatchModeKey mode_key) {
switch (mode_key) {
case TorchDispatchModeKey::PROXY:
return "ProxyTorchDispatchMode";
case TorchDispatchModeKey::FAKE:
return "FakeTensorMode";
default:
return "UNKNOWN_MODE";
}
}
} // namespace c10::impl