mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[foreach_map] Initial foreach map HOP impl for inference (#142098)
This is the initial foreach map HOP for pointwise ops which will be extended in the future to support grouped GEMMs and other ops. This PR utilizes PrimHOPBase class to represent foreach_map as a HOP with a single subgraph. The way this is implemented is that the user API `foreach_map` provides a single pointwise torch op, and internally this function calls a polyfill which has the same semantics as a foreach op (ie iterates over lists of operands applying the op elementwise). The higher order op is passed through the stack down to inductor where a lowering in essence inlines the subgraph into the main graph. This is done by interpreting it with a pointwise subgraph lowering, grouping the outputs by device, and registering the output buffers as foreach groups as applicable. For testing I was able to reuse the existing foreach tests by creating a wrapper function which matches the foreach op interfaces for those tests and then run all of the existing foreach tests on foreach_map. TODO before landing: * Add tests for general functions * Test warning if unsupported op will block fusion Followups: * I need to add tests for backwards (this will be a followup PR because backwards will require other work as well) Pull Request resolved: https://github.com/pytorch/pytorch/pull/142098 Approved by: https://github.com/eellison
This commit is contained in:
parent
bd199bc754
commit
de313f1155
7 changed files with 259 additions and 35 deletions
|
|
@ -5,6 +5,7 @@ import unittest
|
|||
|
||||
import torch
|
||||
import torch._inductor
|
||||
from torch._higher_order_ops import foreach_map
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
|
|
@ -31,23 +32,64 @@ except (unittest.SkipTest, ImportError) as e:
|
|||
sys.exit(0)
|
||||
raise
|
||||
|
||||
|
||||
def foreach_map_wrapper(op):
|
||||
def wrapper(*args, **kwargs):
|
||||
return foreach_map(op, (args), **kwargs)
|
||||
|
||||
wrapper.__name__ = "foreach_map_" + op.__name__
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_op(x, y):
|
||||
return torch.add(x, y)
|
||||
|
||||
|
||||
def addrecip_op(x, y):
|
||||
return torch.reciprocal(torch.add(x, y))
|
||||
|
||||
|
||||
def addcmul_op(x, y, z):
|
||||
return torch.mul(torch.add(x, y), z)
|
||||
|
||||
|
||||
def recipaddmul_op(x, y, z):
|
||||
return torch.mul(torch.add(torch.reciprocal(x), y), z)
|
||||
|
||||
|
||||
inplace_bin_ops_under_test = [
|
||||
torch._foreach_add_,
|
||||
torch._foreach_mul_,
|
||||
torch._foreach_sub_,
|
||||
torch._foreach_div_,
|
||||
]
|
||||
ternary_ops_under_test = [
|
||||
foreach_map_wrapper(addcmul_op),
|
||||
foreach_map_wrapper(recipaddmul_op),
|
||||
]
|
||||
|
||||
bin_ops_under_test = [
|
||||
torch._foreach_add,
|
||||
torch._foreach_mul,
|
||||
torch._foreach_sub,
|
||||
torch._foreach_div,
|
||||
foreach_map_wrapper(torch.add),
|
||||
foreach_map_wrapper(torch.mul),
|
||||
foreach_map_wrapper(torch.sub),
|
||||
foreach_map_wrapper(torch.div),
|
||||
foreach_map_wrapper(addrecip_op),
|
||||
foreach_map_wrapper(add_op),
|
||||
torch._foreach_maximum,
|
||||
torch._foreach_minimum,
|
||||
torch._foreach_clamp_max,
|
||||
torch._foreach_clamp_min,
|
||||
aten._foreach_copy,
|
||||
foreach_map_wrapper(torch.maximum),
|
||||
foreach_map_wrapper(torch.minimum),
|
||||
foreach_map_wrapper(torch.clamp_max),
|
||||
foreach_map_wrapper(torch.clamp_min),
|
||||
foreach_map_wrapper(aten.copy),
|
||||
]
|
||||
|
||||
un_ops_under_test = [
|
||||
|
|
@ -57,18 +99,26 @@ un_ops_under_test = [
|
|||
torch._foreach_abs,
|
||||
torch._foreach_sqrt,
|
||||
torch._foreach_rsqrt,
|
||||
foreach_map_wrapper(torch.reciprocal),
|
||||
foreach_map_wrapper(torch.neg),
|
||||
foreach_map_wrapper(torch.sign),
|
||||
foreach_map_wrapper(torch.abs),
|
||||
]
|
||||
compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]
|
||||
all_ops = parametrize(
|
||||
"op", bin_ops_under_test + un_ops_under_test, name_fn=lambda f: f.__name__
|
||||
"op",
|
||||
ternary_ops_under_test + bin_ops_under_test + un_ops_under_test,
|
||||
name_fn=lambda f: f.__name__,
|
||||
)
|
||||
bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
|
||||
inplace_bin_ops = parametrize(
|
||||
"op", inplace_bin_ops_under_test, name_fn=lambda f: f.__name__
|
||||
)
|
||||
scalar_bin_ops = parametrize("op", bin_ops_under_test[:4], name_fn=lambda f: f.__name__)
|
||||
scalar_bin_ops = parametrize(
|
||||
"op", bin_ops_under_test[:10], name_fn=lambda f: f.__name__
|
||||
)
|
||||
scalar_tensor_bin_ops = parametrize(
|
||||
"op", bin_ops_under_test[:2], name_fn=lambda f: f.__name__
|
||||
"op", bin_ops_under_test[:10], name_fn=lambda f: f.__name__
|
||||
)
|
||||
decomp_ops = parametrize("op", compose_ops, name_fn=lambda f: f.__name__)
|
||||
|
||||
|
|
@ -79,12 +129,21 @@ def gen_args(op):
|
|||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(20, 20, device="cuda:0"),
|
||||
)
|
||||
elif op in bin_ops_under_test:
|
||||
return (
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(20, 20, device="cuda:0"),
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(20, 20, device="cuda:0"),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(20, 20, device="cuda:0"),
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(20, 20, device="cuda:0"),
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(20, 20, device="cuda:0"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -108,11 +167,16 @@ class ForeachTests(TestCase):
|
|||
def fn(a0, a1):
|
||||
return op([a0, a1])
|
||||
|
||||
else:
|
||||
elif op in bin_ops_under_test:
|
||||
|
||||
def fn(a0, a1, b0, b1):
|
||||
return op([a0, a1], [b0, b1])
|
||||
|
||||
else:
|
||||
|
||||
def fn(a0, a1, b0, b1, c0, c1):
|
||||
return op([a0, a1], [b0, b1], [c0, c1])
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
gen_args(op),
|
||||
|
|
@ -174,12 +238,18 @@ class ForeachTests(TestCase):
|
|||
c = op([a0, a1])
|
||||
return torch._foreach_sqrt(c)
|
||||
|
||||
else:
|
||||
elif op in bin_ops_under_test:
|
||||
|
||||
def fn(a0, a1, b0, b1):
|
||||
c = op([a0, a1], [b0, b1])
|
||||
return c, torch._foreach_add([a0, a1], c)
|
||||
|
||||
else:
|
||||
|
||||
def fn(a0, a1, b0, b1, c0, c1):
|
||||
c = op([a0, a1], [b0, b1], [c0, c1])
|
||||
return c, torch._foreach_add([a0, a1], c)
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
gen_args(op),
|
||||
|
|
@ -232,7 +302,7 @@ class ForeachTests(TestCase):
|
|||
return op([a0])
|
||||
|
||||
args = (torch.rand(10, 10, device="cuda:0"),)
|
||||
else:
|
||||
elif op in bin_ops_under_test:
|
||||
|
||||
def fn(a0, b0):
|
||||
return op([a0], [b0])
|
||||
|
|
@ -242,6 +312,17 @@ class ForeachTests(TestCase):
|
|||
torch.rand(10, 10, device="cuda:0"),
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def fn(a0, b0, c0):
|
||||
return op([a0], [b0], [c0])
|
||||
|
||||
args = (
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
torch.rand(10, 10, device="cuda:0"),
|
||||
)
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
args,
|
||||
|
|
@ -342,12 +423,18 @@ class ForeachTests(TestCase):
|
|||
c = op([a0, a1])
|
||||
return torch.mul(c[0], a0)
|
||||
|
||||
else:
|
||||
elif op in bin_ops_under_test:
|
||||
|
||||
def fn(a0, a1, b0, b1):
|
||||
c = op([a0, a1], [b0, b1])
|
||||
return torch.mul(c[0], a0)
|
||||
|
||||
else:
|
||||
|
||||
def fn(a0, a1, b0, b1, c0, c1):
|
||||
c = op([a0, a1], [b0, b1], [c0, c1])
|
||||
return torch.mul(c[0], a0)
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
gen_args(op),
|
||||
|
|
@ -382,13 +469,20 @@ class ForeachTests(TestCase):
|
|||
c1 = torch.add(a1, a1)
|
||||
return op([c0, c1])
|
||||
|
||||
else:
|
||||
elif op in bin_ops_under_test:
|
||||
|
||||
def fn(a0, a1, b0, b1):
|
||||
c0 = torch.add(a0, b0)
|
||||
c1 = torch.add(a1, b1)
|
||||
return op([a0, a1], [c0, c1])
|
||||
|
||||
else:
|
||||
|
||||
def fn(a0, a1, b0, b1, c0, c1):
|
||||
c0 = torch.add(a0, b0)
|
||||
c1 = torch.add(a1, b1)
|
||||
return op([a0, a1], [b0, b1], [c0, c1])
|
||||
|
||||
self.check_model_cuda(
|
||||
fn, gen_args(op), reference_in_float=False, check_lowp=False
|
||||
)
|
||||
|
|
@ -428,7 +522,7 @@ class ForeachTests(TestCase):
|
|||
e1 = torch.mul(d[1], a1)
|
||||
return [e0, e1]
|
||||
|
||||
else:
|
||||
elif op in bin_ops_under_test:
|
||||
|
||||
def fn(a0, a1, b0, b1):
|
||||
c0 = torch.add(a0, b0)
|
||||
|
|
@ -438,6 +532,16 @@ class ForeachTests(TestCase):
|
|||
e1 = torch.mul(d[1], a1)
|
||||
return [e0, e1]
|
||||
|
||||
else:
|
||||
|
||||
def fn(a0, a1, b0, b1, c0, c1):
|
||||
c0 = torch.add(a0, b0)
|
||||
c1 = torch.add(a1, b1)
|
||||
d = op([a0, a1], [b0, b1], [c0, c1])
|
||||
e0 = torch.mul(d[0], a0)
|
||||
e1 = torch.mul(d[1], a1)
|
||||
return [e0, e1]
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
gen_args(op),
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ Python polyfills for common builtins.
|
|||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
from typing import Any, Callable, Sequence, TYPE_CHECKING
|
||||
from itertools import repeat as _repeat
|
||||
from typing import Any, Callable, List, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -177,6 +178,28 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
|
|||
return obj
|
||||
|
||||
|
||||
def foreach_map_fn(*args):
|
||||
op = args[0]
|
||||
new_args: List[Any] = []
|
||||
at_least_one_list = False
|
||||
for arg in args[1:]:
|
||||
if not isinstance(arg, (list, tuple)):
|
||||
new_args.append(_repeat(arg))
|
||||
else:
|
||||
at_least_one_list = True
|
||||
new_args.append(arg)
|
||||
|
||||
# Just apply op once to args if there are no lists
|
||||
if not at_least_one_list:
|
||||
return op(*args[1:])
|
||||
|
||||
out = []
|
||||
for unpacked in zip(*new_args):
|
||||
out.append(op(*unpacked))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def foreach_lerp_inplace(self, end, weight):
|
||||
# decompose foreach lerp into constituent ops, prevents a graph break due to
|
||||
# converting a value to a scalar when arg[2] is a single tensor
|
||||
|
|
|
|||
|
|
@ -293,6 +293,7 @@ manual_torch_name_rule_map = {
|
|||
"torch._functorch.deprecated.grad_and_value": UserFunctionVariable,
|
||||
"torch._functorch.deprecated.vjp": UserFunctionVariable,
|
||||
# everything else
|
||||
"torch._higher_order_ops.foreach_map.foreach_map": UserFunctionVariable,
|
||||
"torch._constrain_as_size": UserFunctionVariable,
|
||||
"torch._tensor._convert": UserFunctionVariable,
|
||||
"torch.jit._unwrap_optional": UserFunctionVariable,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from torch._higher_order_ops.flex_attention import (
|
|||
flex_attention,
|
||||
flex_attention_backward,
|
||||
)
|
||||
from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map
|
||||
from torch._higher_order_ops.hints_wrap import hints_wrapper
|
||||
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
|
||||
from torch._higher_order_ops.prim_hop_base import PrimHOPBase
|
||||
|
|
@ -19,4 +20,6 @@ __all__ = [
|
|||
"flex_attention_backward",
|
||||
"hints_wrapper",
|
||||
"PrimHOPBase",
|
||||
"foreach_map",
|
||||
"_foreach_map",
|
||||
]
|
||||
|
|
|
|||
26
torch/_higher_order_ops/foreach_map.py
Normal file
26
torch/_higher_order_ops/foreach_map.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase
|
||||
|
||||
|
||||
class ForeachMap(PrimHOPBase):
|
||||
def __init__(self):
|
||||
super().__init__("foreach_map")
|
||||
|
||||
def __call__(self, fn, operands, *unused, **kwargs): # type: ignore[override]
|
||||
fn = FunctionWithNoFreeVars(fn)
|
||||
return super().__call__(fn, operands, **kwargs)
|
||||
|
||||
|
||||
_foreach_map = ForeachMap()
|
||||
|
||||
|
||||
def foreach_map(
|
||||
op: Callable, operands: Any, *unused: Tuple[Any], **kwargs: Dict[str, Any]
|
||||
):
|
||||
from torch._dynamo.polyfills import foreach_map_fn
|
||||
|
||||
args = (op,) + operands
|
||||
return _foreach_map(foreach_map_fn, args, **kwargs)
|
||||
|
|
@ -9,6 +9,7 @@ import operator
|
|||
import os
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -93,12 +94,46 @@ aten = torch.ops.aten
|
|||
tr_c10d = torch.ops.tr_c10d
|
||||
prims = torch.ops.prims
|
||||
needs_realized_inputs: Set[torch._ops.OpOverload] = set()
|
||||
foreach_ops: Set[torch._ops.OpOverload] = set()
|
||||
foreach_ops: Set[torch._ops.OpOverload] = {torch._higher_order_ops._foreach_map} # type: ignore[arg-type]
|
||||
inplace_foreach_ops: Set[torch._ops.OpOverload] = set()
|
||||
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
|
||||
quantized_decomposed = torch.ops.quantized_decomposed
|
||||
|
||||
|
||||
def cur_node_has_non_foreach_users():
|
||||
for node in V.graph.current_node.users:
|
||||
for user in node.users:
|
||||
if not (user.op == "call_function" and (user.target in foreach_ops)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# group by device, whether any of the inputs are dynamic
|
||||
# note arg_pairs may or may not be a pair
|
||||
# foreach_map for example just passes output buffers here
|
||||
def group_foreach_args(arg_pairs: Iterable[Union[Tuple[Any, Any], Any]]):
|
||||
out = defaultdict(list)
|
||||
unpack_args = False
|
||||
for i, args in enumerate(arg_pairs):
|
||||
if not isinstance(args, Iterable):
|
||||
unpack_args = True
|
||||
args = (args,)
|
||||
use_foreach = (
|
||||
not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
|
||||
)
|
||||
device = None
|
||||
for t in args:
|
||||
if isinstance(t, TensorBox):
|
||||
device = t.data.get_device()
|
||||
break
|
||||
assert device is not None, "foreach op should have at least one tensor arg"
|
||||
if unpack_args:
|
||||
(args,) = args
|
||||
out[(device, use_foreach)].append((i, args))
|
||||
return out
|
||||
|
||||
|
||||
def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
|
||||
"""Get layout constraints. Returns None if there are no layout constraints."""
|
||||
if not isinstance(fn, torch._ops.OpOverload):
|
||||
|
|
@ -608,33 +643,11 @@ def make_pointwise(
|
|||
|
||||
def make_foreach_pointwise(pw_fn, allow_alpha=False):
|
||||
def inner(*inputs: List[List[TensorBox]], alpha=1):
|
||||
# group by device, whether any of the inputs are dynamic, and whether their types match
|
||||
# (proxy for type promotion)
|
||||
def group_args(arg_pairs):
|
||||
out = defaultdict(list)
|
||||
for i, args in enumerate(arg_pairs):
|
||||
use_foreach = (
|
||||
not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
|
||||
)
|
||||
device = None
|
||||
for t in args:
|
||||
if isinstance(t, TensorBox):
|
||||
device = t.data.get_device()
|
||||
break
|
||||
assert (
|
||||
device is not None
|
||||
), "foreach op should have at least one tensor arg"
|
||||
out[(device, use_foreach)].append((i, args))
|
||||
return out
|
||||
|
||||
realize_outputs = (
|
||||
len(V.graph.current_node.users) == 0
|
||||
or V.graph.current_node.target in inplace_foreach_ops
|
||||
or cur_node_has_non_foreach_users()
|
||||
)
|
||||
for node in V.graph.current_node.users:
|
||||
for user in node.users:
|
||||
if not (user.op == "call_function" and (user.target in foreach_ops)):
|
||||
realize_outputs = True
|
||||
|
||||
a_list_input = None
|
||||
for input in inputs:
|
||||
|
|
@ -653,7 +666,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
|
|||
else:
|
||||
broadcast_inputs.append(input)
|
||||
|
||||
groups = group_args(zip(*broadcast_inputs))
|
||||
groups = group_foreach_args(zip(*broadcast_inputs))
|
||||
|
||||
outputs = [None] * len(a_list_input)
|
||||
for (device, use_foreach), group in groups.items():
|
||||
|
|
@ -697,6 +710,59 @@ def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False):
|
|||
return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
|
||||
|
||||
|
||||
@register_lowering(torch._higher_order_ops._foreach_map)
|
||||
def _foreach_map(subgraph, *args, **kwargs):
|
||||
"""
|
||||
This lowers an invocation of foreach_map
|
||||
The way this works is that an arbitrary N-arg func is provided by the user, looped over by the
|
||||
polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args)
|
||||
and then traced into a subgraph by dynamo.
|
||||
This code allows us to inline the subgraph into the main graph lowering using the PontwiseSubgraphLowering.
|
||||
The graph outputs represent the vertically fused sequence of ops, and then register_operation_list
|
||||
below registers the buffers as horizontally fuseable in the scheduler.
|
||||
"""
|
||||
realize_outputs = (
|
||||
len(V.graph.current_node.users) == 0 or cur_node_has_non_foreach_users()
|
||||
)
|
||||
|
||||
from .subgraph_lowering import PointwiseSubgraphLowering
|
||||
|
||||
inputs = args[0] # nested tuple
|
||||
|
||||
gm = subgraph.graph_module
|
||||
pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
|
||||
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
|
||||
pw_subgraph.run(*inputs)
|
||||
|
||||
sub_outputs = pw_subgraph.graph_outputs
|
||||
# group outputs by device and register as foreach
|
||||
assert sub_outputs # mypy lol
|
||||
groups = group_foreach_args(sub_outputs)
|
||||
|
||||
outputs = [None] * len(sub_outputs)
|
||||
for (device, use_foreach), group in groups.items():
|
||||
operation_list: List[str] = []
|
||||
for (
|
||||
output_ind,
|
||||
output,
|
||||
) in group:
|
||||
outputs[output_ind] = output
|
||||
|
||||
if (
|
||||
V.graph.has_feature(device, BackendFeature.FOREACH)
|
||||
and use_foreach
|
||||
and realize_outputs
|
||||
):
|
||||
output.realize()
|
||||
operation_list.append(output.get_operation_name())
|
||||
|
||||
if operation_list:
|
||||
V.graph.register_operation_list(operation_list)
|
||||
|
||||
assert all(x is not None for x in outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
@register_lowering(prims.convert_element_type, type_promotion_kind=None)
|
||||
def _convert_element_type(x: TensorBox, dtype: torch.dtype):
|
||||
if dtype.is_complex or x.get_dtype().is_complex:
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
|
|||
"triton_kernel_wrapper_mutation",
|
||||
"triton_kernel_wrapper_functional",
|
||||
"hints_wrapper",
|
||||
"foreach_map",
|
||||
]
|
||||
|
||||
torch.library.define(
|
||||
|
|
|
|||
Loading…
Reference in a new issue