pytorch/torch/_higher_order_ops
eellison 92b7e610ab [Inductor changes] Invoke Quant (#139102)
Adds an `invoke_quant` higher-order operator, as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).

The primary motivations are

- Unifying scattered reasoning for quant operators throughout the code base

- Ease of pattern matching - see this very large pattern-match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426)), compared to the pattern I have in the tests:

```python
# Imports, `test_pass`, and the handler stub are added here for completeness;
# only the `register_graph_pattern` call is from the original test snippet.
import torch
from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    Ignored,
    PatternMatcherPass,
    register_graph_pattern,
)

test_pass = PatternMatcherPass()

@register_graph_pattern(
    CallFunction(
        torch.ops.aten.mm,
        CallFunction(
            torch.ops.higher_order.invoke_quant,
            Ignored(),
            Ignored(),
            Ignored(),
            scheme="nf4",
        ),
        Arg(),
    ),
    pass_dict=test_pass,
)
def nf4_mm_pattern(match, *args, **kwargs):  # hypothetical handler name
    ...
```

- Ability to specify Inductor-specific logic, like codegen'ing the operators in lower precision or forcing fusion into a matmul.

Example graph:

``` Python
 ===== AFTER POST GRAD =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
         # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
             # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
            mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
            add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
            return add
```
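For orientation, here is a minimal sketch of user code that could produce a graph of this shape. The import path, the no-argument `InvokeQuant()` construction, and the use of `torch.compile` are assumptions on my part; the invocation style follows the `InvokeQuant` object API described further below.

```python
import torch
from torch._higher_order_ops import InvokeQuant  # assumed import path

def gn(x, y):
    # Traced into repeated_subgraph0 above: a mul followed by an add.
    return x * y + y

invoke_quant = InvokeQuant()  # Inductor-specific options could be passed here

@torch.compile
def fn(x, y):
    # Shows up post-grad as a single invoke_quant node with scheme='nf4'.
    return invoke_quant(gn, (x, y), scheme="nf4")

fn(torch.randn(8), torch.randn(8))
```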

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)`, where the `scheme` kwarg will not always be present.

I wasn't sure exactly how the Inductor-specific configurations like `codegen_low_precision` should be passed through. I didn't want to stuff them all in as kwargs, and I didn't want them to affect pattern matching, so they will be stored in the meta of the node itself. Following that, I wanted the invocation of the HOP to match how it will show up in the graph, so I decided to have it be an object that is then invoked for the tracing.

```python
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
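To make the meta-based storage concrete, here is a rough sketch of how a later graph pass might read those options back off the node; the `"quant_options"` meta key and the attribute name are illustrative assumptions, not the PR's actual API.

```python
import torch

def find_low_precision_invoke_quant(gm: torch.fx.GraphModule):
    # Collect invoke_quant nodes whose (assumed) meta-stored options
    # request low-precision codegen.
    nodes = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.higher_order.invoke_quant:
            opts = node.meta.get("quant_options")  # assumed meta key
            if opts is not None and getattr(opts, "codegen_low_precision", False):
                nodes.append(node)
    return nodes
```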
TODO: stop requiring args to be packed in a tuple; this will be done following https://github.com/pytorch/pytorch/pull/139162.

Feedback welcome.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
2025-02-08 19:30:19 +00:00
| File | Last commit | Date |
| --- | --- | --- |
| `__init__.py` | [Inductor changes] Invoke Quant (#139102) | 2025-02-08 19:30:19 +00:00 |
| `_invoke_quant.py` | [Inductor changes] Invoke Quant (#139102) | 2025-02-08 19:30:19 +00:00 |
| `aoti_call_delegate.py` | Introduce aoti_call_delegate HOP (#145630) | 2025-01-31 04:57:36 +00:00 |
| `associative_scan.py` | Require that all HOPs be imported at import torch time (#145939) | 2025-01-29 22:27:52 +00:00 |
| `auto_functionalize.py` | [auto_functionalized] Support Tensor(a!)[]? (#145400) | 2025-02-05 14:52:39 +00:00 |
| `cond.py` | [cond] remove warning for unsupported tuple returns (#145766) | 2025-01-28 03:13:36 +00:00 |
| `effects.py` | | |
| `executorch_call_delegate.py` | | |
| `flat_apply.py` | Barebones flat_apply HOP (#146060) | 2025-02-01 16:17:48 +00:00 |
| `flex_attention.py` | Fix broken meta function for flex-attention backwards (#146563) | 2025-02-08 04:13:52 +00:00 |
| `foreach_map.py` | | |
| `hints_wrap.py` | | |
| `invoke_subgraph.py` | | |
| `map.py` | | |
| `out_dtype.py` | | |
| `prim_hop_base.py` | | |
| `run_const_graph.py` | | |
| `scan.py` | [scan] scan dim handling in user-facing scan() (#145179) | 2025-01-30 21:09:07 +00:00 |
| `strict_mode.py` | | |
| `torchbind.py` | | |
| `triton_kernel_wrap.py` | [inductor] Make triton kernel autotune config defaults backward-compatible (#145494) | 2025-01-29 00:31:39 +00:00 |
| `utils.py` | [while_loop] specialize when cond_fn return constants (#144515) | 2025-01-30 19:02:34 +00:00 |
| `while_loop.py` | [hop] fix unbacked_bindings meta for while_loop (#143559) | 2025-01-30 21:33:09 +00:00 |
| `wrap.py` | Require that all HOPs be imported at import torch time (#145939) | 2025-01-29 22:27:52 +00:00 |