Adds a `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).
The primary motivations are
- Unifying scattered reasoning for quant operators throughout the code base
- Easy of pattern matching - see this very large pattern match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426). Compared to the pattern I have in the tests:
```
@register_graph_pattern(
CallFunction(
torch.ops.aten.mm,
CallFunction(
torch.ops.higher_order.invoke_quant,
Ignored(),
Ignored(),
Ignored(),
scheme="nf4",
),
Arg(),
),
pass_dict=test_pass,
)
```
- Ability to specify inductor specific logic, like codegen'ing the operators in lower precision, or forcing fusion to a matmul.
Example graph:
``` Python
===== AFTER POST GRAD =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
# File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
repeated_subgraph0 = self.repeated_subgraph0
invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); repeated_subgraph0 = arg0_1 = arg1_1 = None
return (invoke_quant,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
# File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = None
add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1); mul = arg1_1 = None
return add
```
The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present.
I wasn't sure exactly how the inductor specific configurations like `codgen_in_low_precision` should be passed through. I didnt want to stuff them all in as kwargs, and I didn't want to have them affect pattern matching. So they will be stored as meta of the node itself. And, following that, I wanted the invocation of the hop to match how it will show up in the graph. So I decided to have it be an object that is then invoked for the tracing.
```
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - not require the packing of args in a tuple, will do following https://github.com/pytorch/pytorch/pull/139162.
Feedback welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
# Summary
Fixes https://github.com/pytorch/pytorch/issues/146377
So what was the original problem: we were codegening a really weird epilogue:
```Python
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
xindex = index_k + 64*index_n + 64*off_hkv*ks2 + 128*off_zq*ks2
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 64*index_n + off_hkv*ks1, dk.shape)), dk, mask)
x5 = (xindex % ks3)
tmp2 = tl.load(out_ptr0 + (x5 + ks1*off_hkv), mask, eviction_policy='evict_last')
tl.store(out_ptr1 + (tl.broadcast_to(xindex, dk.shape)), tmp2, mask)
```
This epilogue was writing and then reading from overlapping regions of memory causing a race condition.
### Why were we generating this epilgoue
During the lowering we created a buffer w/ a different size/stride from the expected return strides. I :think this added an implicit node (for doing the permutation of this wrongly strided output to the the expected one from the meta func. The scheduler for some reason thought it was okay to fuse this into the epilogue, tbh I dont know why.
This fixes the broken meta func and the original repro. I will add a test but it is hard to pop, better than nothing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146563
Approved by: https://github.com/Chillee
This PR:
- adds pytree.register_constant for registering a class to be treated as
a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
mark_traceable working. A lot more work is necessary to get the custom
operator case working (when make_fx sees a custom operator with PyTree
arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
the current state to unblock the mark_traceable and custom ops
work.
Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
I added a really simple test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
Summary:
Previously, aoti compile node is represented as a kernel-less custom op in the exported program. The node was not eager runnable, which is a common practice for numerical validation during lowering.
I introduce a new HOP to address this.
The schema is following
```
aoti_call_delegate(lower_moduel: AOTInductorEPModule, original_gm: fx.GraphModule, weights: List[Tensor], inputs: List[Tensor])
```
There are a few problems exposed by HOP
- AOTI expects a FX graph with weights as getattr nodes, aka stateful graph. HOP expect graph_module arguments to be stateless. Export serializer also expect a stateless graph. Currently, to make AOTI happy, I am making `original_gm` stateful, and bypassing the serialization for `original_gm`.
- As a result, the HOP is not re-traceable, as functionalization on stateful graph module argument will fail.
Test Plan: buck2 test 'fbcode//mode/opt' fbcode//deeplearning/aot_inductor/cpu/test:cpu_lowering_utils_test
Reviewed By: zhxchen17
Differential Revision: D68359391
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145630
Approved by: https://github.com/zou3519
E.g. torch.ops.higher_order.cond does not exist until it is imported,
which is bad if it shows up in an FX graph or is used in some code
somewhere.
This PR also makes some more HOPs get imported at `import torch` time.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145939
Approved by: https://github.com/ydwu4
ghstack dependencies: #145938
If a model was torch.packaged using triton<=3.1, any user-defined
autotuned kernels will have reps/warmups burned in with the old defaults
(100/25). If this model is loaded with triton>=3.2, inductor's checks for
unsupported non-default autotune args will fail, because triton.Autotuner's
defaults for these parameters has changed to `None`. Let's explicitly support
those values for backward compatibility with these older models.
Differential Revision: [D68561014](https://our.internmc.facebook.com/intern/diff/D68561014/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145494
Approved by: https://github.com/aorenste
This PR implements the user-facing dim change, i.e., that the scan dim provided by the user is always moved to dim 0 and then the associative_scan operation always operates on dim 0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139864
Approved by: https://github.com/ydwu4
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined triton kernel handling (in most cases) for these new triton commits.
What this PR fixes:
* in triton_kernel_wrap.py, AST->TTIR parsing was to be updated for the new triton API
* ir.py - don't remove None args when using newer triton versions
* wrapper.py - update signature & constant handling
What this doesn't fix:
* correct None handling - I want to do a closer look at constant handling (including None, equal_to_1, and other constants).
* cpp wrapper (which needs to be fixed for both user-defined triton kernels and inductor-generated kernels)
test/inductor/test_triton_kernels.py passed on triton commit 74de6b46, with the exception of three tests (those shown here: 1374074098)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145348
Approved by: https://github.com/jansel
ghstack dependencies: #145051
This PR implements the idea of checking input mutations through tensor version and check aliasing via storage from @zou3519. Previously, we rely on whether there's a in place op that takes placeholder input, which doesn't take views into account.
When writing the PR, I also noticed a bug in previous input mutation checking logic: we were checking the whether there are operators functionalized_f where all the mutating ops have been replaced so we won't be able to detect any thing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145298
Approved by: https://github.com/zou3519
Summary: Introduce `is_hop_single_tensor_return` field to the `Node` class in serialization so that during deserialization when there is a single return, we know whether it is a tuple of a single element or a single element.
Test Plan:
```
buck2 run @mode/dev-nosan sigmoid/inference/test:e2e_test_cpu -- -r E2ETestCPUCond
buck2 run @mode/dev-nosan sigmoid/inference/test:test_passes -- -r test_const_folding2
```
Differential Revision: D66991624
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143227
Approved by: https://github.com/zhxchen17
We support running a single Autotuner for each Triton kernel. Currently,
if there are multiple autotuning decorators, the subsequent ones will be
silently ignored.
Instead, we should raise an error here to avoid silent incorrectness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143519
Approved by: https://github.com/aakhundov
FIXES https://github.com/pytorch/pytorch/issues/142313
So with previous HOPs, compiled autograd could just inline into their body and get their post-dispatch aten representation. You can't do that with this flex attention HOP, which just wants any proxy tracing mechanism to insert it into its graph. Okay, compiled autograd does use proxy tracing, so we can do that.
This is safe because other than the reenter_make_fx call, there were no other make_fx internals usage in the HOP. And compiled autograd specializes on the AOT backward's saved symints which should cover any changes in shapes to the inputs of the HOP.
However, there's still an issue: Dynamo doesn't know how to handle `FlexAttentionBackwardHOP` and will graph break, so the flex attention backward is running in eager as of this PR. The tlparse looks really scuffed after the compiled autograd capture: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpMMHBEH/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143155
Approved by: https://github.com/drisspg
This is the initial foreach map HOP for pointwise ops which will be extended in the future to support grouped GEMMs and other ops.
This PR utilizes PrimHOPBase class to represent foreach_map as a HOP with a single subgraph. The way this is implemented is that the user API `foreach_map` provides a single pointwise torch op, and internally this function calls a polyfill which has the same semantics as a foreach op (ie iterates over lists of operands applying the op elementwise). The higher order op is passed through the stack down to inductor where a lowering in essence inlines the subgraph into the main graph. This is done by interpreting it with a pointwise subgraph lowering, grouping the outputs by device, and registering the output buffers as foreach groups as applicable. For testing I was able to reuse the existing foreach tests by creating a wrapper function which matches the foreach op interfaces for those tests and then run all of the existing foreach tests on foreach_map.
TODO before landing:
* Add tests for general functions
* Test warning if unsupported op will block fusion
Followups:
* I need to add tests for backwards (this will be a followup PR because backwards will require other work as well)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142098
Approved by: https://github.com/eellison
The idea is the parent hop's fake tensor mode should ignore the newly allocated unbacked symints in subgraph because the bindings of unbacked symbols in the subgraph should already be done when we trace the subgraph. E.g. if there's an operator in subgraph that produces unbacked symints, the track_tensor_tree logic for that operator will take care of it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142031
Approved by: https://github.com/zou3519
ghstack dependencies: #142162
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.
Having these `# type: ignore` linger around is not ideal for two reasons:
- They are verbose/ugly syntatically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.
I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.
This PR should have no effect on runtime at all.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
We've been using it privately for half a year and everything's been
good. This PR:
1. Makes torch.library.triton_op public
2. Renames capture_triton -> wrap_triton. We got feedback that no one
knew what "capture triton" does.
3. Makes torch.library.wrap_triton public.
triton_op is used to construct a Python custom operator that may call 1+
triton kernels. Each of those triton kernels must be annotated with
wrap_triton.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141880
Approved by: https://github.com/albanD
ghstack dependencies: #141894
This PR fixes the shape checks that are done in the associative_scan operation.
Before all shapes of the input leaves were required to be the same. With this PR only the shapes of the output of the combine_fn and the input leaves need to be the same, but not among the input leaves.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141698
Approved by: https://github.com/ydwu4
# Summary
The follow up PR to: https://github.com/pytorch/pytorch/pull/137526. In this pr, we actually update the lowerings for the flex_attention backwards kernel to generate fused backward gradient calculations for any captured buffers that require grads.
We are doing this using tl.atomic_add to scatter the correct gradients into zeroed out buffer for any captured buffers that required grads. Added many test cases and found. Along the way found some masking bugs.
There are likely some performance cliffs here, specifically with D-types and on different GPUs. Planned to do this in a follow-up and profile the current strategy. We are explicitly choosing reduced memory over increased performance right now.
By using atomics, we do not need to realize a full attention scores matrix. However, this comes with two downsides. One, this is potentially slower in some cases, and two, the gradient calculation for any captured buffers is non-deterministic.
## Worked Example
Lets do the case where you are reading from one bias that doesn't require grad and using this to index into another that does.
ScoreMod:
```Python
bias = torch.randn(
params.seq_length,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
offset = torch.randint(
0,
params.seq_length,
(params.seq_length,),
device=self.device,
)
def score_mod(score, b, h, q_idx, kv_idx):
return score + bias[offset[q_idx]]
```
I am removing all but the new subgraph injected into the backwards:
``` Python
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
grad_scores = (dsT)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
tmp4 = (dsT).to(tl.float32)
tl.atomic_add(out_ptr1 + (tl.broadcast_to(tl.load(in_ptr16 + idx_m), tmp4.shape)), tmp4, scatter_mask, sem='relaxed')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
## Key points
* We always accumulate to float 32 grad buffers regardless of the type in the forward. This is because we normally do all computation intra kernel w/ fp32 accumulation and we want the same behavior for atomic additions
* We are currently restricted to 1 scatter in the kenrel. I have some ideas on fx rewrites that would remove this restrictions but for now have nice error message w/ work around and will leave as a follow up.
* Will do more extensive performance/ memory profiling in a follow up.
### Toy E2E example
I have a toy E2E training example PR in the gym for now: https://github.com/pytorch-labs/attention-gym/pull/84/
I plan to update to a realistic learnable bias before landing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137452
Approved by: https://github.com/Chillee
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
This PR adds caching for user defined triton kernels by putting the transitive closure of source code in node.meta along with constant arguments.
One HUGE hack we do here is a node looks like
```
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(1, 1, 1)], tma_descriptor_
metadata = {}, kwargs = {'in_ptr0': arg0_1, 'in_ptr1': arg1_1, 'out_ptr': arg0_1}, tensors_to_clone = ['out_ptr']);
```
so we use regex to remove `kernel_idx = 0, constant_args_idx = 1` parts as they are not relevant to cache hash. This is horrible and I'd like to eventually not use pickle as a hashing alternative but this is a longer project.
Differential Revision: [D65895744](https://our.internmc.facebook.com/intern/diff/D65895744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140326
Approved by: https://github.com/zou3519