[foreach_map] Initial foreach map HOP impl for inference (#142098)

This is the initial foreach_map HOP for pointwise ops; it will be extended in the future to support grouped GEMMs and other ops.

This PR uses the PrimHOPBase class to represent foreach_map as a HOP with a single subgraph. The user-facing API `foreach_map` takes a single pointwise torch op; internally it calls a polyfill with the same semantics as a foreach op (i.e. it iterates over lists of operands, applying the op elementwise). The higher-order op is passed through the stack down to inductor, where a lowering essentially inlines the subgraph into the main graph: the subgraph is interpreted with a pointwise subgraph lowering, the outputs are grouped by device, and the output buffers are registered as foreach groups where applicable. For testing, I was able to reuse the existing foreach tests by creating a wrapper function that matches the foreach op interface expected by those tests and running all of the existing foreach tests on foreach_map.
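
As a quick illustration, here is a minimal usage sketch (the tensors, shapes, and the `add_op` helper are hypothetical; it assumes a CUDA device and follows the same wrapper pattern as the tests below):

```python
import torch
from torch._higher_order_ops import foreach_map


# A user-provided pointwise op, applied elementwise across the operand lists.
def add_op(x, y):
    return torch.add(x, y)


xs = [torch.rand(10, 10, device="cuda"), torch.rand(20, 20, device="cuda")]
ys = [torch.rand(10, 10, device="cuda"), torch.rand(20, 20, device="cuda")]

# foreach_map takes the op plus a tuple of operand lists, mirroring
# torch._foreach_add([...], [...]) semantics; inductor fuses the results.
compiled = torch.compile(lambda a, b: foreach_map(add_op, (a, b)))
out = compiled(xs, ys)  # list of two tensors
```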

TODO before landing:
* Add tests for general functions
* Test that a warning is emitted when an unsupported op blocks fusion

Followups:
* Add tests for backward (this will be a follow-up PR because backward requires additional work as well)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142098
Approved by: https://github.com/eellison
Michael Lazos 2024-12-11 21:32:11 +00:00 committed by PyTorch MergeBot
parent bd199bc754
commit de313f1155
7 changed files with 259 additions and 35 deletions


@@ -5,6 +5,7 @@ import unittest
import torch
import torch._inductor
from torch._higher_order_ops import foreach_map
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -31,23 +32,64 @@ except (unittest.SkipTest, ImportError) as e:
sys.exit(0)
raise
def foreach_map_wrapper(op):
def wrapper(*args, **kwargs):
return foreach_map(op, (args), **kwargs)
wrapper.__name__ = "foreach_map_" + op.__name__
return wrapper
def add_op(x, y):
return torch.add(x, y)
def addrecip_op(x, y):
return torch.reciprocal(torch.add(x, y))
def addcmul_op(x, y, z):
return torch.mul(torch.add(x, y), z)
def recipaddmul_op(x, y, z):
return torch.mul(torch.add(torch.reciprocal(x), y), z)
inplace_bin_ops_under_test = [
torch._foreach_add_,
torch._foreach_mul_,
torch._foreach_sub_,
torch._foreach_div_,
]
ternary_ops_under_test = [
foreach_map_wrapper(addcmul_op),
foreach_map_wrapper(recipaddmul_op),
]
bin_ops_under_test = [
torch._foreach_add,
torch._foreach_mul,
torch._foreach_sub,
torch._foreach_div,
foreach_map_wrapper(torch.add),
foreach_map_wrapper(torch.mul),
foreach_map_wrapper(torch.sub),
foreach_map_wrapper(torch.div),
foreach_map_wrapper(addrecip_op),
foreach_map_wrapper(add_op),
torch._foreach_maximum,
torch._foreach_minimum,
torch._foreach_clamp_max,
torch._foreach_clamp_min,
aten._foreach_copy,
foreach_map_wrapper(torch.maximum),
foreach_map_wrapper(torch.minimum),
foreach_map_wrapper(torch.clamp_max),
foreach_map_wrapper(torch.clamp_min),
foreach_map_wrapper(aten.copy),
]
un_ops_under_test = [
@@ -57,18 +99,26 @@ un_ops_under_test = [
torch._foreach_abs,
torch._foreach_sqrt,
torch._foreach_rsqrt,
foreach_map_wrapper(torch.reciprocal),
foreach_map_wrapper(torch.neg),
foreach_map_wrapper(torch.sign),
foreach_map_wrapper(torch.abs),
]
compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]
all_ops = parametrize(
"op", bin_ops_under_test + un_ops_under_test, name_fn=lambda f: f.__name__
"op",
ternary_ops_under_test + bin_ops_under_test + un_ops_under_test,
name_fn=lambda f: f.__name__,
)
bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
inplace_bin_ops = parametrize(
"op", inplace_bin_ops_under_test, name_fn=lambda f: f.__name__
)
scalar_bin_ops = parametrize("op", bin_ops_under_test[:4], name_fn=lambda f: f.__name__)
scalar_bin_ops = parametrize(
"op", bin_ops_under_test[:10], name_fn=lambda f: f.__name__
)
scalar_tensor_bin_ops = parametrize(
"op", bin_ops_under_test[:2], name_fn=lambda f: f.__name__
"op", bin_ops_under_test[:10], name_fn=lambda f: f.__name__
)
decomp_ops = parametrize("op", compose_ops, name_fn=lambda f: f.__name__)
@@ -79,12 +129,21 @@ def gen_args(op):
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
)
elif op in bin_ops_under_test:
return (
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
)
else:
return (
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
torch.rand(10, 10, device="cuda:0"),
torch.rand(20, 20, device="cuda:0"),
)
@@ -108,11 +167,16 @@ class ForeachTests(TestCase):
def fn(a0, a1):
return op([a0, a1])
else:
elif op in bin_ops_under_test:
def fn(a0, a1, b0, b1):
return op([a0, a1], [b0, b1])
else:
def fn(a0, a1, b0, b1, c0, c1):
return op([a0, a1], [b0, b1], [c0, c1])
self.check_model_cuda(
fn,
gen_args(op),
@@ -174,12 +238,18 @@ class ForeachTests(TestCase):
c = op([a0, a1])
return torch._foreach_sqrt(c)
else:
elif op in bin_ops_under_test:
def fn(a0, a1, b0, b1):
c = op([a0, a1], [b0, b1])
return c, torch._foreach_add([a0, a1], c)
else:
def fn(a0, a1, b0, b1, c0, c1):
c = op([a0, a1], [b0, b1], [c0, c1])
return c, torch._foreach_add([a0, a1], c)
self.check_model_cuda(
fn,
gen_args(op),
@@ -232,7 +302,7 @@ class ForeachTests(TestCase):
return op([a0])
args = (torch.rand(10, 10, device="cuda:0"),)
else:
elif op in bin_ops_under_test:
def fn(a0, b0):
return op([a0], [b0])
@@ -242,6 +312,17 @@ class ForeachTests(TestCase):
torch.rand(10, 10, device="cuda:0"),
)
else:
def fn(a0, b0, c0):
return op([a0], [b0], [c0])
args = (
torch.rand(10, 10, device="cuda:0"),
torch.rand(10, 10, device="cuda:0"),
torch.rand(10, 10, device="cuda:0"),
)
self.check_model_cuda(
fn,
args,
@@ -342,12 +423,18 @@ class ForeachTests(TestCase):
c = op([a0, a1])
return torch.mul(c[0], a0)
else:
elif op in bin_ops_under_test:
def fn(a0, a1, b0, b1):
c = op([a0, a1], [b0, b1])
return torch.mul(c[0], a0)
else:
def fn(a0, a1, b0, b1, c0, c1):
c = op([a0, a1], [b0, b1], [c0, c1])
return torch.mul(c[0], a0)
self.check_model_cuda(
fn,
gen_args(op),
@@ -382,13 +469,20 @@ class ForeachTests(TestCase):
c1 = torch.add(a1, a1)
return op([c0, c1])
else:
elif op in bin_ops_under_test:
def fn(a0, a1, b0, b1):
c0 = torch.add(a0, b0)
c1 = torch.add(a1, b1)
return op([a0, a1], [c0, c1])
else:
def fn(a0, a1, b0, b1, c0, c1):
c0 = torch.add(a0, b0)
c1 = torch.add(a1, b1)
return op([a0, a1], [b0, b1], [c0, c1])
self.check_model_cuda(
fn, gen_args(op), reference_in_float=False, check_lowp=False
)
@@ -428,7 +522,7 @@ class ForeachTests(TestCase):
e1 = torch.mul(d[1], a1)
return [e0, e1]
else:
elif op in bin_ops_under_test:
def fn(a0, a1, b0, b1):
c0 = torch.add(a0, b0)
@@ -438,6 +532,16 @@ class ForeachTests(TestCase):
e1 = torch.mul(d[1], a1)
return [e0, e1]
else:
def fn(a0, a1, b0, b1, c0, c1):
c0 = torch.add(a0, b0)
c1 = torch.add(a1, b1)
d = op([a0, a1], [b0, b1], [c0, c1])
e0 = torch.mul(d[0], a0)
e1 = torch.mul(d[1], a1)
return [e0, e1]
self.check_model_cuda(
fn,
gen_args(op),


@@ -8,7 +8,8 @@ Python polyfills for common builtins.
# mypy: allow-untyped-defs
from typing import Any, Callable, Sequence, TYPE_CHECKING
from itertools import repeat as _repeat
from typing import Any, Callable, List, Sequence, TYPE_CHECKING
import torch
@@ -177,6 +178,28 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
return obj
def foreach_map_fn(*args):
op = args[0]
new_args: List[Any] = []
at_least_one_list = False
for arg in args[1:]:
if not isinstance(arg, (list, tuple)):
new_args.append(_repeat(arg))
else:
at_least_one_list = True
new_args.append(arg)
# Just apply op once to args if there are no lists
if not at_least_one_list:
return op(*args[1:])
out = []
for unpacked in zip(*new_args):
out.append(op(*unpacked))
return out
def foreach_lerp_inplace(self, end, weight):
# decompose foreach lerp into constituent ops, prevents a graph break due to
# converting a value to a scalar when arg[2] is a single tensor


@@ -293,6 +293,7 @@ manual_torch_name_rule_map = {
"torch._functorch.deprecated.grad_and_value": UserFunctionVariable,
"torch._functorch.deprecated.vjp": UserFunctionVariable,
# everything else
"torch._higher_order_ops.foreach_map.foreach_map": UserFunctionVariable,
"torch._constrain_as_size": UserFunctionVariable,
"torch._tensor._convert": UserFunctionVariable,
"torch.jit._unwrap_optional": UserFunctionVariable,


@@ -3,6 +3,7 @@ from torch._higher_order_ops.flex_attention import (
flex_attention,
flex_attention_backward,
)
from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._higher_order_ops.prim_hop_base import PrimHOPBase
@@ -19,4 +20,6 @@ __all__ = [
"flex_attention_backward",
"hints_wrapper",
"PrimHOPBase",
"foreach_map",
"_foreach_map",
]


@@ -0,0 +1,26 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, Tuple
from torch._higher_order_ops.prim_hop_base import FunctionWithNoFreeVars, PrimHOPBase
class ForeachMap(PrimHOPBase):
def __init__(self):
super().__init__("foreach_map")
def __call__(self, fn, operands, *unused, **kwargs): # type: ignore[override]
fn = FunctionWithNoFreeVars(fn)
return super().__call__(fn, operands, **kwargs)
_foreach_map = ForeachMap()
def foreach_map(
op: Callable, operands: Any, *unused: Tuple[Any], **kwargs: Dict[str, Any]
):
from torch._dynamo.polyfills import foreach_map_fn
args = (op,) + operands
return _foreach_map(foreach_map_fn, args, **kwargs)


@@ -9,6 +9,7 @@ import operator
import os
import warnings
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from unittest.mock import patch
@@ -93,12 +94,46 @@ aten = torch.ops.aten
tr_c10d = torch.ops.tr_c10d
prims = torch.ops.prims
needs_realized_inputs: Set[torch._ops.OpOverload] = set()
foreach_ops: Set[torch._ops.OpOverload] = set()
foreach_ops: Set[torch._ops.OpOverload] = {torch._higher_order_ops._foreach_map} # type: ignore[arg-type]
inplace_foreach_ops: Set[torch._ops.OpOverload] = set()
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
quantized_decomposed = torch.ops.quantized_decomposed
def cur_node_has_non_foreach_users():
for node in V.graph.current_node.users:
for user in node.users:
if not (user.op == "call_function" and (user.target in foreach_ops)):
return True
return False
# group by device, whether any of the inputs are dynamic
# note arg_pairs may or may not be a pair
# foreach_map for example just passes output buffers here
def group_foreach_args(arg_pairs: Iterable[Union[Tuple[Any, Any], Any]]):
out = defaultdict(list)
unpack_args = False
for i, args in enumerate(arg_pairs):
if not isinstance(args, Iterable):
unpack_args = True
args = (args,)
use_foreach = (
not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
)
device = None
for t in args:
if isinstance(t, TensorBox):
device = t.data.get_device()
break
assert device is not None, "foreach op should have at least one tensor arg"
if unpack_args:
(args,) = args
out[(device, use_foreach)].append((i, args))
return out
def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
"""Get layout constraints. Returns None if there are no layout constraints."""
if not isinstance(fn, torch._ops.OpOverload):
@@ -608,33 +643,11 @@ def make_pointwise(
def make_foreach_pointwise(pw_fn, allow_alpha=False):
def inner(*inputs: List[List[TensorBox]], alpha=1):
# group by device, whether any of the inputs are dynamic, and whether their types match
# (proxy for type promotion)
def group_args(arg_pairs):
out = defaultdict(list)
for i, args in enumerate(arg_pairs):
use_foreach = (
not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
)
device = None
for t in args:
if isinstance(t, TensorBox):
device = t.data.get_device()
break
assert (
device is not None
), "foreach op should have at least one tensor arg"
out[(device, use_foreach)].append((i, args))
return out
realize_outputs = (
len(V.graph.current_node.users) == 0
or V.graph.current_node.target in inplace_foreach_ops
or cur_node_has_non_foreach_users()
)
for node in V.graph.current_node.users:
for user in node.users:
if not (user.op == "call_function" and (user.target in foreach_ops)):
realize_outputs = True
a_list_input = None
for input in inputs:
@@ -653,7 +666,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
else:
broadcast_inputs.append(input)
groups = group_args(zip(*broadcast_inputs))
groups = group_foreach_args(zip(*broadcast_inputs))
outputs = [None] * len(a_list_input)
for (device, use_foreach), group in groups.items():
@@ -697,6 +710,59 @@ def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False):
return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
@register_lowering(torch._higher_order_ops._foreach_map)
def _foreach_map(subgraph, *args, **kwargs):
"""
This lowers an invocation of foreach_map
The way this works is that an arbitrary N-arg func is provided by the user, looped over by the
polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args)
and then traced into a subgraph by dynamo.
This code allows us to inline the subgraph into the main graph lowering using the PointwiseSubgraphLowering.
The graph outputs represent the vertically fused sequence of ops, and then register_operation_list
below registers the buffers as horizontally fuseable in the scheduler.
"""
realize_outputs = (
len(V.graph.current_node.users) == 0 or cur_node_has_non_foreach_users()
)
from .subgraph_lowering import PointwiseSubgraphLowering
inputs = args[0] # nested tuple
gm = subgraph.graph_module
pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
pw_subgraph.run(*inputs)
sub_outputs = pw_subgraph.graph_outputs
# group outputs by device and register as foreach
assert sub_outputs # mypy lol
groups = group_foreach_args(sub_outputs)
outputs = [None] * len(sub_outputs)
for (device, use_foreach), group in groups.items():
operation_list: List[str] = []
for (
output_ind,
output,
) in group:
outputs[output_ind] = output
if (
V.graph.has_feature(device, BackendFeature.FOREACH)
and use_foreach
and realize_outputs
):
output.realize()
operation_list.append(output.get_operation_name())
if operation_list:
V.graph.register_operation_list(operation_list)
assert all(x is not None for x in outputs)
return outputs
@register_lowering(prims.convert_element_type, type_promotion_kind=None)
def _convert_element_type(x: TensorBox, dtype: torch.dtype):
if dtype.is_complex or x.get_dtype().is_complex:


@@ -74,6 +74,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
"triton_kernel_wrapper_mutation",
"triton_kernel_wrapper_functional",
"hints_wrapper",
"foreach_map",
]
torch.library.define(