Add Gelu Related Ops to Triton Codegen (#17713)

Add Gelu/QuickGelu/GeluGrad/QuickGeluGrad support to the Triton codegen so
that these ops can be fused with adjacent supported ops. For example, in
Llama 2, the op can be fused with Mul, which yields an extra 1-2%
performance gain.
This commit is contained in:
Vincent Wang 2023-09-27 19:57:39 +08:00 committed by GitHub
parent a99c965d05
commit e6aa0fa174
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 110 additions and 20 deletions

View file

@ -263,10 +263,20 @@ class TritonCodegen(NodeVisitor):
"Rsqrt": "{indent}{o0} = 1.0 / tl.sqrt({i0})\n",
"Cast": "{indent}{o0} = {i0}.to(tl.{dtype})\n",
"CastBool": "{indent}{o0} = {i0} != 0\n",
"Erf": "{indent}{o0} = tl.libdevice.erf({i0})\n",
"Gelu": "{indent}{o0} = (tl.libdevice.erf({i0} / 1.41421356237) + 1.0) * 0.5\n",
"Erf": "{indent}{o0} = tl.erf({i0})\n",
"Gelu": "{indent}{o0} = {i0} * 0.5 * (tl.math.erf({i0} * 0.70710678118654752440) + 1.0)\n",
"QuickGelu": "{indent}{o0} = {i0} * tl.sigmoid({i0} * {alpha})\n",
"GeluGrad": (
"{indent}{o0} = {i0} * (0.5 * (1.0 + tl.math.erf(0.70710678118654752440 * {i1})) + "
"{i1} * 1.12837916709551257390 * 0.70710678118654752440 * 0.5 * tl.exp(-0.5 * {i1} * {i1}))\n"
),
"QuickGeluGrad": (
"{indent}tmp_v = {i1} * {alpha}\n"
"{indent}tmp_sigmoid = tl.sigmoid(tmp_v)\n"
"{indent}{o0} = {i0} * tmp_sigmoid * (1.0 + tmp_v * (1.0 - tmp_sigmoid))\n"
),
"Exp": "{indent}{o0} = tl.exp({i0})\n",
"Tanh": "{indent}{o0} = tl.libdevice.tanh({i0})\n",
"Tanh": "{indent}{o0} = tl.math.tanh({i0})\n",
"Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n",
"Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n",
"Log": "{indent}{o0} = tl.log({i0})\n",
@ -303,6 +313,9 @@ class TritonCodegen(NodeVisitor):
else:
kwargs["dtype"] = to_dtype.__name__
if op_type == "QuickGelu" or op_type == "QuickGeluGrad":
kwargs["alpha"] = str(node.attributes.get("alpha", 1.702))
if op_type == "Sum":
output_var = kwargs["o0"]
formula = " + ".join([kwargs[f"i{idx}"] for idx in range(len(node.inputs))])

View file

@ -131,6 +131,10 @@ class TypeAndShapeInfer:
"ReduceMax": _infer_reduction,
"ReduceMin": _infer_reduction,
"Sum": _infer_elementwise,
"Gelu": _infer_unary,
"QuickGelu": _infer_unary,
"GeluGrad": _infer_elementwise,
"QuickGeluGrad": _infer_elementwise,
}
@classmethod

View file

@ -5,7 +5,7 @@
from abc import abstractmethod
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np
import sympy
@ -184,14 +184,25 @@ class ComputeNode(IRNode):
Each operator is represented as a ComputeNode.
"""
def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg]):
def __init__(
    self,
    op_type: str,
    inputs: List[TensorArg],
    outputs: List[TensorArg],
    attributes: Optional[Dict[str, Any]] = None,
):
    """Create a compute node for a single operator.

    Args:
        op_type: Operator type name (e.g. "Gelu").
        inputs: Tensor arguments consumed by the operator.
        outputs: Tensor arguments produced by the operator.
        attributes: Optional mapping of operator attribute name to value.
            When omitted, the node gets its own fresh empty dict.
    """
    super().__init__(inputs, outputs)
    self._op_type: str = op_type
    # A literal `{}` default would be created once and shared by every node
    # constructed without attributes, so mutating one node's attributes would
    # leak into all the others. Allocate a new dict per instance instead.
    self._attributes: Dict[str, Any] = {} if attributes is None else attributes
@property
def op_type(self):
# Returns the operator type string that was passed to __init__.
return self._op_type
@property
def attributes(self):
# Returns the attribute dict stored by __init__ (the same object, not a copy).
return self._attributes
class ReduceNode(ComputeNode):
def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator):

View file

@ -9,7 +9,7 @@ from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple
import sympy
from onnx import NodeProto
from onnx import NodeProto, helper
from ._common import AutotuneConfigs, TensorInfo
from ._ir import (
@ -378,7 +378,10 @@ class GraphLowering:
return DropoutNode(inputs, outputs, offset_calc)
if is_reduction_node(node):
return ReduceNode(op_type, inputs, outputs, offset_calc)
return ComputeNode(op_type, inputs, outputs)
attributes = {}
for attr in node.attribute:
attributes[attr.name] = helper.get_attribute_value(attr)
return ComputeNode(op_type, inputs, outputs, attributes)
def _analyze_kernel_io_list(self):
cross_kernel_inputs = set()

View file

@ -36,6 +36,10 @@ _ELEMENTWISE_OPS = {
"DropoutGrad": {"domain": "com.microsoft", "versions": [1]},
"Identity": {"versions": [13], "is_no_op": True},
"Sum": {"versions": [13]},
"Gelu": {"domain": "com.microsoft", "versions": [1]},
"QuickGelu": {"domain": "com.microsoft", "versions": [1]},
"GeluGrad": {"domain": "com.microsoft", "versions": [1]},
"QuickGeluGrad": {"domain": "com.microsoft", "versions": [1]},
}
_REDUCTION_OPS = {

View file

@ -135,8 +135,31 @@ def _torch_layer_norm(input, weight, bias, **kwargs):
return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias)
def _torch_gelu(input):
return torch.nn.functional.gelu(input)
def _torch_quick_gelu(input, **kwargs):
alpha = kwargs.get("alpha", 1.702)
return input * torch.sigmoid(input * alpha)
def _torch_gelu_grad(dy, x):
alpha = 0.70710678118654752440
beta = 1.12837916709551257390 * 0.70710678118654752440 * 0.5
cdf = 0.5 * (1 + torch.erf(x * alpha))
pdf = beta * torch.exp(x * x * -0.5)
return dy * (cdf + x * pdf)
def _torch_quick_gelu_grad(dy, x, **kwargs):
alpha = kwargs.get("alpha", 1.702)
sigmoid = torch.sigmoid(x * alpha)
return dy * sigmoid * (1.0 + x * alpha * (1.0 - sigmoid))
class TorchFuncExecutor:
_INFER_FUNC_MAP = { # noqa: RUF012
_TORCH_FUNC_MAP = { # noqa: RUF012
"Add": _torch_add,
"Sub": _torch_sub,
"Mul": _torch_mul,
@ -154,13 +177,17 @@ class TorchFuncExecutor:
"ReduceMin": _torch_reduce_min,
"Softmax": _torch_softmax,
"LayerNormalization": _torch_layer_norm,
"Gelu": _torch_gelu,
"QuickGelu": _torch_quick_gelu,
"GeluGrad": _torch_gelu_grad,
"QuickGeluGrad": _torch_quick_gelu_grad,
}
@classmethod
def run(cls, op_type, *torch_tensors, **kwargs):
if op_type not in cls._INFER_FUNC_MAP:
if op_type not in cls._TORCH_FUNC_MAP:
raise NotImplementedError(f"Unsupported op type: {op_type}")
return cls._INFER_FUNC_MAP[op_type](*torch_tensors, **kwargs)
return cls._TORCH_FUNC_MAP[op_type](*torch_tensors, **kwargs)
def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwargs):
@ -169,6 +196,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar
pt_inputs = gen_inputs_func(_onnx_dtype_to_torch_dtype(onnx_dtype))
ort_inputs = copy.deepcopy(pt_inputs)
ort_inputs = [tensor.to(torch.uint8) if tensor.dtype == torch.bool else tensor for tensor in ort_inputs]
if "::" in op_type:
_, op_type = op_type.split("::")
pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs)
model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString()
ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs])
@ -260,13 +289,27 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co
del os.environ["ORTMODULE_TUNING_RESULTS_PATH"]
@pytest.mark.parametrize("op_type", ["Add", "Sub", "Mul", "Div"])
@pytest.mark.parametrize(
"op",
[
("Add", {}),
("Sub", {}),
("Mul", {}),
("Div", {}),
("com.microsoft::GeluGrad", {}),
("com.microsoft::QuickGeluGrad", {}),
("com.microsoft::QuickGeluGrad", {"alpha": 1.0}),
],
)
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
@pytest.mark.parametrize("input_shapes", [([1024, 2], [1024, 2]), ([2, 3, 3, 3], [3, 1, 3]), ([2049], [1])])
def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes):
def _create_model(op_type, onnx_dtype):
def test_binary_elementwise_op(op, onnx_dtype, input_shapes):
def _create_model(op_type, onnx_dtype, **kwargs):
domain = ""
if "::" in op_type:
domain, op_type = op_type.split("::")
graph = helper.make_graph(
[helper.make_node(op_type, ["X", "Y"], ["Z"], name="test")],
[helper.make_node(op_type, ["X", "Y"], ["Z"], name="test", domain=domain, **kwargs)],
"test",
[
helper.make_tensor_value_info("X", onnx_dtype, None),
@ -282,7 +325,7 @@ def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes):
torch.randn(*input_shapes[1], dtype=dtype, device=DEVICE),
]
_run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs)
_run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1])
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
@ -303,13 +346,25 @@ def test_sum_op(onnx_dtype, input_shapes):
_run_op_test("Sum", onnx_dtype, _create_model, _gen_inputs)
@pytest.mark.parametrize("op_type", ["Sqrt", "Exp"])
@pytest.mark.parametrize(
"op",
[
("Sqrt", {}),
("Exp", {}),
("com.microsoft::Gelu", {}),
("com.microsoft::QuickGelu", {}),
("com.microsoft::QuickGelu", {"alpha": 1.0}),
],
)
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
@pytest.mark.parametrize("input_shape", [[1024, 4], [2, 3, 3, 3], [2049, 1]])
def test_unary_elementwise_op(op_type, onnx_dtype, input_shape):
def _create_model(op_type, onnx_dtype):
def test_unary_elementwise_op(op, onnx_dtype, input_shape):
def _create_model(op_type, onnx_dtype, **kwargs):
domain = ""
if "::" in op_type:
domain, op_type = op_type.split("::")
graph = helper.make_graph(
[helper.make_node(op_type, ["X"], ["Y"], name="test")],
[helper.make_node(op_type, ["X"], ["Y"], name="test", domain=domain, **kwargs)],
"test",
[helper.make_tensor_value_info("X", onnx_dtype, None)],
[helper.make_tensor_value_info("Y", onnx_dtype, None)],
@ -319,7 +374,7 @@ def test_unary_elementwise_op(op_type, onnx_dtype, input_shape):
def _gen_inputs(dtype):
return [torch.rand(*input_shape, dtype=dtype, device=DEVICE)]
_run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs)
_run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1])
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])