pytorch/torch/_inductor
eellison 92b7e610ab [Inductor changes] Invoke Quant (#139102)
Adds an `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).

The primary motivations are:

- Unifying the scattered reasoning for quant operators throughout the code base

- Ease of pattern matching - see this very large pattern-match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426)). Compare it to the pattern I have in the tests:

``` Python
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
```

- Ability to specify inductor-specific logic, such as codegen'ing the operators in lower precision or forcing fusion into a matmul.

Example graph:

``` Python
 ===== AFTER POST GRAD =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
         # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
             # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
            mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
            add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
            return add
```

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)`, where the `scheme` kwarg will not always be present.
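The calling convention can be sketched in plain Python (illustrative only - `invoke_quant_sketch` and `mul_add` are made-up names; the real op dispatches through the HOP machinery rather than calling the subgraph directly):

```python
def invoke_quant_sketch(subgraph, *args, scheme=None):
    # Mirrors the schema invoke_quant(subgraph, *args, scheme=None):
    # run the subgraph on the positional args; the optional `scheme`
    # kwarg tags the call for pattern matching but does not change
    # the eager result in this sketch.
    return subgraph(*args)

# Hypothetical subgraph matching repeated_subgraph0 in the graph above:
def mul_add(a, b):
    return a * b + b

result = invoke_quant_sketch(mul_add, 3.0, 4.0, scheme="nf4")
# 3.0 * 4.0 + 4.0 == 16.0
```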

I wasn't sure exactly how the inductor-specific configurations like `codegen_low_precision` should be passed through. I didn't want to stuff them all in as kwargs, and I didn't want them to affect pattern matching, so they will be stored in the meta of the node itself. Following that, I wanted the invocation of the hop to match how it will show up in the graph, so I decided to make it an object that is then invoked for the tracing.

``` Python
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - stop requiring the args to be packed in a tuple; will do this as a follow-up to https://github.com/pytorch/pytorch/pull/139162.
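The options-on-meta design above can be sketched in pure Python (all names here - `FakeNode`, `InvokeQuantSketch`, the `"quant_options"` meta key - are stand-ins for illustration; the real classes live in `torch._higher_order_ops.invoke_quant` and `torch.fx`):

```python
from dataclasses import dataclass, field


@dataclass
class FakeNode:
    # Minimal stand-in for an fx.Node: a target, call kwargs, and a
    # free-form meta dict that passes do not pattern-match against.
    target: str
    kwargs: dict
    meta: dict = field(default_factory=dict)


@dataclass
class InvokeQuantSketch:
    # Inductor-only knob; stored on node.meta rather than passed as a
    # kwarg, so it never participates in pattern matching.
    codegen_low_precision: bool = True

    def __call__(self, subgraph, args, scheme=None):
        node = FakeNode(
            target="higher_order.invoke_quant",
            kwargs={"scheme": scheme},
        )
        node.meta["quant_options"] = self
        return node, subgraph(*args)


iq = InvokeQuantSketch(codegen_low_precision=True)
node, out = iq(lambda x, y: x * y + y, (3.0, 4.0), scheme="nf4")
```

After the call, `scheme` is visible in `node.kwargs` (so patterns can match on it), while `codegen_low_precision` rides along only in `node.meta`.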

Feedback welcome.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
2025-02-08 19:30:19 +00:00