pytorch/test/dynamo/test_flat_apply.py
rzou 0f768c7866 Barebones flat_apply HOP (#146060)
This PR:
- adds pytree.register_constant for registering a class to be treated as
  a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
  mark_traceable working. A lot more work is necessary to get the custom
  operator case working (when make_fx sees a custom operator with PyTree
  arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot; I'm shipping it in its
  current state to unblock the mark_traceable and custom ops work. A
  short sketch of the calling convention follows this list.
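
For orientation, here is a minimal, self-contained sketch of the
flat_apply calling convention, mirroring what the test below does (fn
here is a hypothetical stand-in for the test's distance function):

import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.flat_apply import (
    ConstantFunction,
    flat_apply,
    to_graphable,
)

def fn(x, y):
    return x + y

# ConstantFunction flattens to zero leaves; the callable rides in the spec.
_, func_spec = pytree.tree_flatten(ConstantFunction(fn))
# Flatten (args, kwargs) into graphable leaves plus an input spec.
flat_args, in_spec = to_graphable(((torch.ones(2), torch.ones(2)), {}))
# Equivalent to fn(torch.ones(2), torch.ones(2)).
out = flat_apply(func_spec, in_spec, *flat_args)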

Test Plan:
- It's kind of difficult to test that the barebones flat_apply HOP
  "works", so I added a really simple test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
2025-02-01 16:17:48 +00:00

# Owner(s): ["module: dynamo", "module: higher order operators"]
from dataclasses import dataclass

import torch
import torch._dynamo.test_case
import torch.utils._pytree as pytree


def distance(a, b, norm):
    if norm.typ == "l2":
        return torch.sqrt((a.x - b.x).pow(2) + (a.y - b.y).pow(2))
    elif norm.typ == "l1":
        return (a.x - b.x).abs() + (a.y - b.y).abs()


@dataclass
class Norm:
    typ: str


# Norm instances flatten to zero leaves; torch.compile/torch.fx treat
# them as constants carried in the treespec.
pytree.register_constant(Norm)


@dataclass
class Point:
    x: torch.Tensor
    y: torch.Tensor


pytree.register_dataclass(Point)


class FlatApplyTests(torch._dynamo.test_case.TestCase):
    def test_simple(self):
        tensor = torch.tensor

        a = Point(tensor(0.0), tensor(0.0))
        b = Point(tensor(3.0), tensor(4.0))
        norm = Norm("l2")

        args = (a, b)
        kwargs = {"norm": norm}

        from torch._higher_order_ops.flat_apply import (
            ConstantFunction,
            flat_apply,
            is_graphable,
            to_graphable,
        )

        # Wrapping distance in ConstantFunction makes it flatten to an
        # empty list of leaves, with the callable stored in the spec.
        empty_list, func_spec = pytree.tree_flatten(ConstantFunction(distance))
        self.assertEqual(empty_list, [])

        # Every flattened arg must be a valid FX graph input.
        flat_args, in_spec = to_graphable((args, kwargs))
        for arg in flat_args:
            self.assertTrue(is_graphable(arg))

        # Test flat_apply returns same thing as original function
        result = flat_apply(func_spec, in_spec, *flat_args)
        self.assertEqual(result, distance(*args, **kwargs))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()