This PR adds a new `FunctionalTensor` subclass, and `FunctionalTensorMode` torch dispatch mode. Together, this class/mode are a lightweight wrapper around our existing C++ functionalization logic.
This idea came from Ed - later in the stack, I want to be able to run functionalization **underneath** torch_dispatch, when performing tracing in AOTAutograd. I can't do this easily with vanilla C++ functionalization, because it has a dedicated dispatch key that always runs before TorchDispatch. However, by adding a torch_dispatch mode shim around functionalization, we can use functionalization as a torch_dispatch mode, which will make it easier to run underneath other modes later.
This PR provides the basic new classes, and some light testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106404
Approved by: https://github.com/ezyang
PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:
* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs -> performs computation on constituent -> rebinds results into NT
Restrictions:
* Only supports one level of vmap
* Only supports vmapping over dim=0 for NTs
* For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106786
Approved by: https://github.com/zou3519
**Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates:
(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.
I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).
`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes.
`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes".
`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).
There were also some mild codegen changes to support the new enum
(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`
The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.
I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly.
(2) dispatching logic in `python_arg_parser.cpp`
This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode() # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses)
```
Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass.
How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key.
(3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends).
----- original description below -----
The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit
Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.
**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**
This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.
I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.
This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.
**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**
This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs.
**(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's**
Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active.
**(4) Allow traceable tensor subclasses to access storages from python**
This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable"
**(5) Fixed subclass fakeification**
@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.
**(6) make tensor subclasses resizeable**
Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.
**(7) Added a basic DoubleTensor subclass for testing**
I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
Summary: Basically we generate `CustomOpsNativeFunctions.h` for registering custom ops into PyTorch JIT runtime. This header needs to hookup with the C++ kernel implementation of all the custom ops. For this reason it should include ATen headers instead of Executorch headers. This PR changes it.
Test Plan: Rely on existing CI jobs
Differential Revision: D48282828
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107064
Approved by: https://github.com/kirklandsign
This fixes a bug that could occur with python decompositions.
When an operation is intercepted in the c++ code in pytorch the outputs a created as `ExclusivelyOwned<at::Tensor>`s. Later on when it dispatches back to python for the decomposition these tensors have their ownership shared with python. In a normal use case the exclusively owned tensor is released and it's value returned as a non-exclusively owned tensor from the operation. However if the python decomposition throws an error the `ExclusivelyOwned` wrapper destroys the `at::Tensor` leading to a python reference to a tensor which isn't alive (and meaning pytorch falls over in debug mode).
Note this will be a performance hit when handling errors.
Fixes#106790
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106791
Approved by: https://github.com/ezyang
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
* Enables PIE807 + PIE810. PIE807 is do not reimplement list builtin function using lambda and PIE810 is to always fuse startswith / endswith calls (I applied the autofixes for this before we had ruff enabled).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106218
Approved by: https://github.com/albanD
Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf
Hide all Float8 operator implementations behind `#if !defined(C10_MOBILE)` guard to keep Android build size almost unchanged
TODO:
- Refactor duplicated code
- Cleanup unbalanced pragma pop in dtype utils
- Add native implementation on the CUDA size
Co-authored-by: Nikita Shulga <nshulga@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104242
Approved by: https://github.com/albanD
Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf
Hide all Float8 operator implementations behind `#if !defined(C10_MOBILE)` guard to keep Android build size almost unchanged
TODO:
- Refactor duplicated code
- Cleanup unbalanced pragma pop in dtype utils
- Add native implementation on the CUDA size
Co-authored-by: Nikita Shulga <nshulga@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104242
Approved by: https://github.com/albanD
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Based on this [code search](https://fburl.com/code/gjcnw8ly) (*.yaml with `dispatch: CPU:`), update all files found to use
```
kernels:
- arg_meta: None
kernel_name:
```
instead of
```
dispatch:
CPU:
```
---
## Code changes:
- `fbcode/executorch/codegen/tools/gen_oplist.py`
- Strip ET specific fields prior to calling parse_native_yaml_struct
---
## Files edited that are not `*functions.yaml` or `custom_ops.yaml`
- fbcode/executorch/kernels/optimized/optimized.yaml
- fbcode/executorch/kernels/quantized/quantized.yaml
- fbcode/executorch/kernels/test/custom_kernel_example/my_functions.yaml
---
## Found Files that were not edited
**Dispatched to more than just CPU**
- fbcode/caffe2/aten/src/ATen/native/native_functions.yaml
- xplat/caffe2/aten/src/ATen/native/native_functions.yaml
- xros/third-party/caffe2/caffe2/aten/src/ATen/native/native_functions.yaml
**Grouped ops.yaml path**
- fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/ops.yaml
---
**Design Doc:** https://docs.google.com/document/d/1gq4Wz2R6verKJ2EFseLyPdAF0wqomnCrVDDJpRkYsRw/edit?kh_source=GDOCS#heading=h.8raqyft9y50
Differential Revision: [D46952067](https://our.internmc.facebook.com/intern/diff/D46952067/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D46952067/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104070
Approved by: https://github.com/larryliu0820
Summary: Currently we rely on root operator, but we also need to check for et_kernel_metadata for used specialized kernels.
Test Plan: contbuild & OSS CI
Reviewed By: Jack-Khuu
Differential Revision: D46882119
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104005
Approved by: https://github.com/Jack-Khuu
Fixes https://github.com/pytorch/pytorch/issues/103132
This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: We normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to `CompositeImplicitAutograd`, and a C++ decomp registered *directly* to the `Functionalize` key, so the C++ decomp gets precedence over the python decomp.
The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the autograd dispatch key, and go straight to the functionalize dispatch key. Matmul has both a python decomp and a c++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call `.sizes()` on a tensor with dynamic shapes" error.
For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the `CompositeImplicitAutograd` key, this PR just automatically registers that decomp to the `Functionalize` key at the same time.
I'm trying to remember now why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I couldn't remember (@zou3519 any chance you remember?).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103275
Approved by: https://github.com/ezyang, https://github.com/zou3519
At high current implementation of constrains functions (constrain_as_**) will raise exception for the following code snippets:
```
def f(x):
a = x.item()
constrain_as_size(a, 4, 7)
return torch.empty((a, 4))
inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```
The reason is because current constrain logic is:
1) Purely python so it won't survive AOT export (the full node is gone after AOT export since AOT export only maintains aten level op).
2) Utilize side effect to add range constraints for traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertion is turned on (by default). [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion node based on range constrains extracted from shape env of symbol during another interpretation round.
4). However, since 1), in the round of AOT export, range constraints logic won't run for symbols generated during this round. And later there is no range constrains information available for assertion round and caused issue.
5) As a result of above, it will failure at `torch.empty((a, 4))` (there is no constrains for `a` that it must be positive).
The fix here is just to implement range constrain logic as a native aten op (CPU implementation as no-op) to make it be able to survive AOT export.
**NOTE:**
[Logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to capture case when non `SymInt` is passed in and reused in the new `_constrain_range`. The reason is when non `SymInt` is provided:
* If it directly calls `sym_constrain_range`, the C++ version will be called which will be no-op.
* So in this case it calls `constrain_range_int` instead to be able to capture issue like user provides a input whose tensor's shape could be out of range during exporting, like the following for above code example:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
Summary:
This API is used by the gen_executorch.py to check whether a kernel with specified kernel key is used or not.
Test Plan:
```
buck test xplat/caffe2/tools:test_torchgen_executorch
buck run fbcode//executorch/codegen/tools:test_gen_oplist_real_model
```
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103184
Approved by: https://github.com/larryliu0820