pytorch/test/higher_order_ops/test_invoke_quant.py
eellison 92b7e610ab [Inductor changes] Invoke Quant (#139102)
Adds a `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).

The primary motivations are

- Unifying scattered reasoning for quant operators throughout the code base

- Ease of pattern matching - see this very large pattern-match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426)). Compare it to the pattern I have in the tests:

```
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
```

- Ability to specify inductor specific logic, like codegen'ing the operators in lower precision, or forcing fusion to a matmul.

Example graph:

``` Python
 ===== AFTER POST GRAD =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
         # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
             # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
            mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
            add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
            return add
```

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present.

I wasn't sure exactly how the inductor-specific configurations like `codegen_in_low_precision` should be passed through. I didn't want to stuff them all in as kwargs, and I didn't want to have them affect pattern matching. So they will be stored as meta of the node itself. And, following that, I wanted the invocation of the hop to match how it will show up in the graph. So I decided to have it be an object that is then invoked for the tracing.

```
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - not require the packing of args in a tuple, will do following https://github.com/pytorch/pytorch/pull/139162.

Feedback welcome.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
2025-02-08 19:30:19 +00:00

183 lines
5.1 KiB
Python

# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950
import contextlib
import logging
import unittest

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._higher_order_ops import InvokeQuant
from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    Ignored,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase

# Shared InvokeQuant wrapper used across the tests below. Calling it traces the
# wrapped function into a single torch.ops.higher_order.invoke_quant node in
# the compiled graph (schema: invoke_quant(subgraph, *args, scheme=None)).
invoke_quant_tracer = InvokeQuant()
@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeQuant(TestCase):
    """Backend-parameterized tests for the `invoke_quant` higher-order op.

    This class is abstract: `backend` is empty here and the class is deleted at
    module scope; concrete subclasses set `backend` to "eager", "aot_eager",
    or "inductor".
    """

    # Compile backend under test; overridden by subclasses.
    backend = ""

    def _check_against_eager(self, fn, ref, *inputs):
        """Compile `fn` with self.backend, run it on detached clones of
        `inputs`, and assert the result equals the eager reference `ref`.

        Clones use the recommended detach-then-clone order so the clone is
        never recorded by autograd.
        """
        clones = [t.detach().clone().requires_grad_(False) for t in inputs]
        res = torch.compile(fn, backend=self.backend)(*clones)
        self.assertEqual(ref, res)

    def test_simple(self):
        # Invoke through the shared tracer object, additionally threading the
        # tracer itself through as quant_options.
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return invoke_quant_tracer(
                gn, (x, y), scheme="nf4", quant_options=invoke_quant_tracer
            )[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        self._check_against_eager(fn, gn(x, y)[0], x, y)

    def test_construct_inline(self):
        # InvokeQuant constructed inline with a non-default inductor option.
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return InvokeQuant(codegen_low_precision=False)(gn, (x, y), scheme="nf4")[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        self._check_against_eager(fn, gn(x, y)[0], x, y)

    def test_inline(self):
        # Default-constructed InvokeQuant, created inside the compiled region.
        def gn(x, y):
            return (torch.mul(x, y) + y,)

        def fn(x, y):
            return InvokeQuant()(gn, (x, y), scheme="nf4")[0]

        x = torch.randn(8, requires_grad=False)
        y = torch.randn(8, requires_grad=False)
        self._check_against_eager(fn, gn(x, y)[0], x, y)

    def test_multiple(self):
        # Enable post-grad graph dumping so the inductor run produces
        # checkable logs; restore defaults afterwards so the global logging
        # setting does not leak into other tests.
        torch._logging.set_logs(post_grad_graphs=True)
        try:
            def gn(x, y):
                return torch.mul(x, y) + y

            def fn(x, y, z):
                o1 = invoke_quant_tracer(gn, (x, y), scheme="nf4")
                o2 = invoke_quant_tracer(gn, (y, z), scheme="nf4")
                return o1 + o2

            x = torch.randn(8, requires_grad=False)
            y = torch.randn(8, requires_grad=False)
            z = torch.randn(8, requires_grad=False)
            ref = fn(x, y, z)

            # Only inductor emits the post-grad graph logs we inspect below.
            log_context = (
                contextlib.nullcontext()
                if self.backend != "inductor"
                else self.assertLogs(logger="torch._inductor", level=logging.DEBUG)
            )
            with log_context as log:
                res = torch.compile(fn, backend=self.backend)(x, y, z)
            self.assertEqual(ref, res)

            if self.backend == "inductor":
                logs = "\n".join(r.getMessage() for r in log.records)
                f = FileCheck()
                f.check("AFTER POST GRAD")
                # Two distinct invocations should each get their own subgraph...
                f.check("subgraph0").check("subgraph1")
                # ...and each lowers to an invoke_quant call carrying scheme nf4.
                for _ in range(2):
                    f.check("torch.ops.higher_order.invoke_quant(").check_same("nf4")
                f.run(logs)
        finally:
            # Reset torch logging to its defaults.
            torch._logging.set_logs()
class TestInvokeQuantEager(TestInvokeQuant):
    """Run the shared invoke_quant tests against the plain eager backend."""

    backend = "eager"
class TestInvokeQuantAotEager(TestInvokeQuant):
    """Run the shared invoke_quant tests through AOTAutograd (aot_eager)."""

    backend = "aot_eager"
class TestInvokeQuantInductor(TestInvokeQuant):
    """Run the shared tests with inductor, plus inductor-only pattern matching."""

    backend = "inductor"

    def test_pattern_matching(self):
        """invoke_quant nodes should be matchable in post-grad passes, keyed on
        the `scheme` kwarg: a graph with scheme="nf4" matches, one without a
        scheme does not."""
        counter = 0
        test_pass = PatternMatcherPass()

        def my_pass(g):
            return test_pass.apply(g)

        def gn(x, y):
            return torch.mul(x, y) + y

        def fn(x, y, z):
            # invoke_quant(..., scheme="nf4") feeding a matmul -> should match.
            return invoke_quant_tracer(gn, (x, y), scheme="nf4") @ z

        def fn_no_match(x, y, z):
            # No scheme kwarg -> the pattern below must not fire.
            return invoke_quant_tracer(gn, (x, y)) @ z

        x = torch.randn(64, 64, requires_grad=False)
        y = torch.randn(64, 64, requires_grad=False)
        z = torch.randn(64, 64, requires_grad=False)

        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
        def quant_matching(match: Match, *args, **kwargs):
            nonlocal counter
            counter += 1

        with unittest.mock.patch(
            "torch._inductor.config.post_grad_custom_pre_pass", my_pass
        ):
            torch.compile(fn)(x, y, z)
            self.assertEqual(counter, 1)

            # Compiling the scheme-less variant must leave the counter untouched.
            torch.compile(fn_no_match)(x, y, z)
            self.assertEqual(counter, 1)
# Remove the abstract base so the test runner does not collect and run it
# with an empty backend; only the concrete backend subclasses should run.
del TestInvokeQuant

if __name__ == "__main__":
    run_tests()