mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
This PR:
- adds `pytree.register_constant` for registering a class to be treated as a constant by torch.compile/torch.fx;
- adds a very barebones `flat_apply` HOP. This should be sufficient to get mark_traceable working. A lot more work is necessary to get the custom operator case working (when make_fx sees a custom operator with PyTree arg types, it needs to emit a call to the `flat_apply` HOP).
- I expect the `flat_apply` HOP to change a lot; I want to ship it in its current state to unblock the mark_traceable and custom ops work.

Test Plan:
- It's difficult to test that the barebones `flat_apply` HOP "works", so I added a very simple test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
66 lines · 1.5 KiB · Python
# Owner(s): ["module: dynamo", "module: higher order operators"]
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch.utils._pytree as pytree
|
|
|
|
|
|
def distance(a, b, norm):
    """Return the distance between points ``a`` and ``b`` under ``norm``.

    Args:
        a: first point; anything with ``x`` and ``y`` attributes supporting
            tensor arithmetic (e.g. ``Point``).
        b: second point, same requirements as ``a``.
        norm: object whose ``typ`` attribute selects the metric:
            ``"l2"`` (Euclidean) or ``"l1"`` (Manhattan).

    Returns:
        The distance, computed with torch ops (a tensor for tensor inputs).

    Raises:
        ValueError: if ``norm.typ`` is not ``"l2"`` or ``"l1"``.
    """
    if norm.typ == "l2":
        return torch.sqrt((a.x - b.x).pow(2) + (a.y - b.y).pow(2))
    elif norm.typ == "l1":
        return (a.x - b.x).abs() + (a.y - b.y).abs()
    # Fail loudly rather than implicitly returning None, which would only
    # surface later as a confusing comparison failure in a test.
    raise ValueError(f"unsupported norm type: {norm.typ!r}")
|
|
|
|
|
|
@dataclass(frozen=True)
class Norm:
    """Selects the metric used by ``distance`` through its ``typ`` field.

    This class is registered with ``pytree.register_constant`` in this module,
    so its instances are meant to be treated as constants rather than
    flattened into leaves. ``frozen=True`` makes instances immutable and
    hashable (a plain ``@dataclass`` sets ``__hash__`` to ``None``), which
    suits a constant. NOTE(review): newer PyTorch versions of
    ``register_constant`` reject unhashable classes; verify against the
    installed version.
    """

    typ: str  # metric name; ``distance`` understands "l2" and "l1"
|
|
|
|
|
|
# Register Norm as a pytree constant type, i.e. one treated as a constant by
# torch.compile/torch.fx. Presumably this keeps Norm instances out of the
# flattened leaves that test_simple checks with ``is_graphable``; confirm
# against torch.utils._pytree.
pytree.register_constant(Norm)
|
|
|
|
|
|
@dataclass()
class Point:
    """A two-dimensional point whose coordinates are tensors.

    Registered with ``pytree.register_dataclass`` in this module, so pytree
    flattening descends into its fields.
    """

    x: torch.Tensor  # first coordinate
    y: torch.Tensor  # second coordinate
|
|
|
|
|
|
# Register Point as a pytree dataclass node so that flattening a Point
# descends into its fields (assumed: both ``x`` and ``y`` become leaves,
# i.e. the default field handling; confirm against torch.utils._pytree).
pytree.register_dataclass(Point)
|
|
|
|
|
|
class FlatApplyTests(torch._dynamo.test_case.TestCase):
    """Tests for the barebones ``flat_apply`` higher-order operator."""

    def test_simple(self):
        """``flat_apply`` on flattened inputs matches calling ``distance``.

        Both branches of ``distance`` ("l2" and "l1") are run through the
        flatten / ``flat_apply`` round trip, and the results are also checked
        against hand-computed values.
        """
        from torch._higher_order_ops.flat_apply import (
            ConstantFunction,
            flat_apply,
            is_graphable,
            to_graphable,
        )

        a = Point(torch.tensor(0.0), torch.tensor(0.0))
        b = Point(torch.tensor(3.0), torch.tensor(4.0))

        # Wrapping the function in ConstantFunction contributes no leaves, so
        # func_spec alone has to describe the function to flat_apply.
        empty_list, func_spec = pytree.tree_flatten(ConstantFunction(distance))
        self.assertEqual(empty_list, [])

        # Hand-computed distances between (0, 0) and (3, 4).
        expected_by_typ = {"l2": 5.0, "l1": 7.0}
        for typ, expected in expected_by_typ.items():
            with self.subTest(typ=typ):
                args = (a, b)
                kwargs = {"norm": Norm(typ)}

                flat_args, in_spec = to_graphable((args, kwargs))
                for arg in flat_args:
                    self.assertTrue(is_graphable(arg))

                # Test flat_apply returns same thing as original function
                result = flat_apply(func_spec, in_spec, *flat_args)
                self.assertEqual(result, distance(*args, **kwargs))
                self.assertEqual(result, torch.tensor(expected))
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand off to the shared Dynamo test runner, reached through the
    # torch._dynamo.test_case module already imported at the top of the file.
    torch._dynamo.test_case.run_tests()
|