mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Add Gelu Related Ops to Triton Codegen (#17713)
Add Gelu/QuickGelu/GeluGrad/QuickGeluGrad support to Triton Codegen so that it can be fused with some other connected supported Ops. For example, in llama2, it can be fused with Mul so we will have extra 1-2% perf gain.
This commit is contained in:
parent
a99c965d05
commit
e6aa0fa174
6 changed files with 110 additions and 20 deletions
|
|
@ -263,10 +263,20 @@ class TritonCodegen(NodeVisitor):
|
|||
"Rsqrt": "{indent}{o0} = 1.0 / tl.sqrt({i0})\n",
|
||||
"Cast": "{indent}{o0} = {i0}.to(tl.{dtype})\n",
|
||||
"CastBool": "{indent}{o0} = {i0} != 0\n",
|
||||
"Erf": "{indent}{o0} = tl.libdevice.erf({i0})\n",
|
||||
"Gelu": "{indent}{o0} = (tl.libdevice.erf({i0} / 1.41421356237) + 1.0) * 0.5\n",
|
||||
"Erf": "{indent}{o0} = tl.erf({i0})\n",
|
||||
"Gelu": "{indent}{o0} = {i0} * 0.5 * (tl.math.erf({i0} * 0.70710678118654752440) + 1.0)\n",
|
||||
"QuickGelu": "{indent}{o0} = {i0} * tl.sigmoid({i0} * {alpha})\n",
|
||||
"GeluGrad": (
|
||||
"{indent}{o0} = {i0} * (0.5 * (1.0 + tl.math.erf(0.70710678118654752440 * {i1})) + "
|
||||
"{i1} * 1.12837916709551257390 * 0.70710678118654752440 * 0.5 * tl.exp(-0.5 * {i1} * {i1}))\n"
|
||||
),
|
||||
"QuickGeluGrad": (
|
||||
"{indent}tmp_v = {i1} * {alpha}\n"
|
||||
"{indent}tmp_sigmoid = tl.sigmoid(tmp_v)\n"
|
||||
"{indent}{o0} = {i0} * tmp_sigmoid * (1.0 + tmp_v * (1.0 - tmp_sigmoid))\n"
|
||||
),
|
||||
"Exp": "{indent}{o0} = tl.exp({i0})\n",
|
||||
"Tanh": "{indent}{o0} = tl.libdevice.tanh({i0})\n",
|
||||
"Tanh": "{indent}{o0} = tl.math.tanh({i0})\n",
|
||||
"Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n",
|
||||
"Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n",
|
||||
"Log": "{indent}{o0} = tl.log({i0})\n",
|
||||
|
|
@ -303,6 +313,9 @@ class TritonCodegen(NodeVisitor):
|
|||
else:
|
||||
kwargs["dtype"] = to_dtype.__name__
|
||||
|
||||
if op_type == "QuickGelu" or op_type == "QuickGeluGrad":
|
||||
kwargs["alpha"] = str(node.attributes.get("alpha", 1.702))
|
||||
|
||||
if op_type == "Sum":
|
||||
output_var = kwargs["o0"]
|
||||
formula = " + ".join([kwargs[f"i{idx}"] for idx in range(len(node.inputs))])
|
||||
|
|
|
|||
|
|
@ -131,6 +131,10 @@ class TypeAndShapeInfer:
|
|||
"ReduceMax": _infer_reduction,
|
||||
"ReduceMin": _infer_reduction,
|
||||
"Sum": _infer_elementwise,
|
||||
"Gelu": _infer_unary,
|
||||
"QuickGelu": _infer_unary,
|
||||
"GeluGrad": _infer_elementwise,
|
||||
"QuickGeluGrad": _infer_elementwise,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sympy
|
||||
|
|
@ -184,14 +184,25 @@ class ComputeNode(IRNode):
|
|||
Each operator is represented as a ComputeNode.
|
||||
"""
|
||||
|
||||
def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg]):
|
||||
def __init__(
|
||||
self,
|
||||
op_type: str,
|
||||
inputs: List[TensorArg],
|
||||
outputs: List[TensorArg],
|
||||
attributes: Dict[str, Any] = {}, # noqa: B006
|
||||
):
|
||||
super().__init__(inputs, outputs)
|
||||
self._op_type: str = op_type
|
||||
self._attributes: Dict[str, Any] = attributes
|
||||
|
||||
@property
|
||||
def op_type(self):
|
||||
return self._op_type
|
||||
|
||||
@property
|
||||
def attributes(self):
|
||||
return self._attributes
|
||||
|
||||
|
||||
class ReduceNode(ComputeNode):
|
||||
def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from collections import defaultdict
|
|||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
import sympy
|
||||
from onnx import NodeProto
|
||||
from onnx import NodeProto, helper
|
||||
|
||||
from ._common import AutotuneConfigs, TensorInfo
|
||||
from ._ir import (
|
||||
|
|
@ -378,7 +378,10 @@ class GraphLowering:
|
|||
return DropoutNode(inputs, outputs, offset_calc)
|
||||
if is_reduction_node(node):
|
||||
return ReduceNode(op_type, inputs, outputs, offset_calc)
|
||||
return ComputeNode(op_type, inputs, outputs)
|
||||
attributes = {}
|
||||
for attr in node.attribute:
|
||||
attributes[attr.name] = helper.get_attribute_value(attr)
|
||||
return ComputeNode(op_type, inputs, outputs, attributes)
|
||||
|
||||
def _analyze_kernel_io_list(self):
|
||||
cross_kernel_inputs = set()
|
||||
|
|
|
|||
|
|
@ -36,6 +36,10 @@ _ELEMENTWISE_OPS = {
|
|||
"DropoutGrad": {"domain": "com.microsoft", "versions": [1]},
|
||||
"Identity": {"versions": [13], "is_no_op": True},
|
||||
"Sum": {"versions": [13]},
|
||||
"Gelu": {"domain": "com.microsoft", "versions": [1]},
|
||||
"QuickGelu": {"domain": "com.microsoft", "versions": [1]},
|
||||
"GeluGrad": {"domain": "com.microsoft", "versions": [1]},
|
||||
"QuickGeluGrad": {"domain": "com.microsoft", "versions": [1]},
|
||||
}
|
||||
|
||||
_REDUCTION_OPS = {
|
||||
|
|
|
|||
|
|
@ -135,8 +135,31 @@ def _torch_layer_norm(input, weight, bias, **kwargs):
|
|||
return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias)
|
||||
|
||||
|
||||
def _torch_gelu(input):
|
||||
return torch.nn.functional.gelu(input)
|
||||
|
||||
|
||||
def _torch_quick_gelu(input, **kwargs):
|
||||
alpha = kwargs.get("alpha", 1.702)
|
||||
return input * torch.sigmoid(input * alpha)
|
||||
|
||||
|
||||
def _torch_gelu_grad(dy, x):
|
||||
alpha = 0.70710678118654752440
|
||||
beta = 1.12837916709551257390 * 0.70710678118654752440 * 0.5
|
||||
cdf = 0.5 * (1 + torch.erf(x * alpha))
|
||||
pdf = beta * torch.exp(x * x * -0.5)
|
||||
return dy * (cdf + x * pdf)
|
||||
|
||||
|
||||
def _torch_quick_gelu_grad(dy, x, **kwargs):
|
||||
alpha = kwargs.get("alpha", 1.702)
|
||||
sigmoid = torch.sigmoid(x * alpha)
|
||||
return dy * sigmoid * (1.0 + x * alpha * (1.0 - sigmoid))
|
||||
|
||||
|
||||
class TorchFuncExecutor:
|
||||
_INFER_FUNC_MAP = { # noqa: RUF012
|
||||
_TORCH_FUNC_MAP = { # noqa: RUF012
|
||||
"Add": _torch_add,
|
||||
"Sub": _torch_sub,
|
||||
"Mul": _torch_mul,
|
||||
|
|
@ -154,13 +177,17 @@ class TorchFuncExecutor:
|
|||
"ReduceMin": _torch_reduce_min,
|
||||
"Softmax": _torch_softmax,
|
||||
"LayerNormalization": _torch_layer_norm,
|
||||
"Gelu": _torch_gelu,
|
||||
"QuickGelu": _torch_quick_gelu,
|
||||
"GeluGrad": _torch_gelu_grad,
|
||||
"QuickGeluGrad": _torch_quick_gelu_grad,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def run(cls, op_type, *torch_tensors, **kwargs):
|
||||
if op_type not in cls._INFER_FUNC_MAP:
|
||||
if op_type not in cls._TORCH_FUNC_MAP:
|
||||
raise NotImplementedError(f"Unsupported op type: {op_type}")
|
||||
return cls._INFER_FUNC_MAP[op_type](*torch_tensors, **kwargs)
|
||||
return cls._TORCH_FUNC_MAP[op_type](*torch_tensors, **kwargs)
|
||||
|
||||
|
||||
def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwargs):
|
||||
|
|
@ -169,6 +196,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar
|
|||
pt_inputs = gen_inputs_func(_onnx_dtype_to_torch_dtype(onnx_dtype))
|
||||
ort_inputs = copy.deepcopy(pt_inputs)
|
||||
ort_inputs = [tensor.to(torch.uint8) if tensor.dtype == torch.bool else tensor for tensor in ort_inputs]
|
||||
if "::" in op_type:
|
||||
_, op_type = op_type.split("::")
|
||||
pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs)
|
||||
model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString()
|
||||
ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs])
|
||||
|
|
@ -260,13 +289,27 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co
|
|||
del os.environ["ORTMODULE_TUNING_RESULTS_PATH"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op_type", ["Add", "Sub", "Mul", "Div"])
|
||||
@pytest.mark.parametrize(
|
||||
"op",
|
||||
[
|
||||
("Add", {}),
|
||||
("Sub", {}),
|
||||
("Mul", {}),
|
||||
("Div", {}),
|
||||
("com.microsoft::GeluGrad", {}),
|
||||
("com.microsoft::QuickGeluGrad", {}),
|
||||
("com.microsoft::QuickGeluGrad", {"alpha": 1.0}),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
|
||||
@pytest.mark.parametrize("input_shapes", [([1024, 2], [1024, 2]), ([2, 3, 3, 3], [3, 1, 3]), ([2049], [1])])
|
||||
def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes):
|
||||
def _create_model(op_type, onnx_dtype):
|
||||
def test_binary_elementwise_op(op, onnx_dtype, input_shapes):
|
||||
def _create_model(op_type, onnx_dtype, **kwargs):
|
||||
domain = ""
|
||||
if "::" in op_type:
|
||||
domain, op_type = op_type.split("::")
|
||||
graph = helper.make_graph(
|
||||
[helper.make_node(op_type, ["X", "Y"], ["Z"], name="test")],
|
||||
[helper.make_node(op_type, ["X", "Y"], ["Z"], name="test", domain=domain, **kwargs)],
|
||||
"test",
|
||||
[
|
||||
helper.make_tensor_value_info("X", onnx_dtype, None),
|
||||
|
|
@ -282,7 +325,7 @@ def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes):
|
|||
torch.randn(*input_shapes[1], dtype=dtype, device=DEVICE),
|
||||
]
|
||||
|
||||
_run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs)
|
||||
_run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
|
||||
|
|
@ -303,13 +346,25 @@ def test_sum_op(onnx_dtype, input_shapes):
|
|||
_run_op_test("Sum", onnx_dtype, _create_model, _gen_inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op_type", ["Sqrt", "Exp"])
|
||||
@pytest.mark.parametrize(
|
||||
"op",
|
||||
[
|
||||
("Sqrt", {}),
|
||||
("Exp", {}),
|
||||
("com.microsoft::Gelu", {}),
|
||||
("com.microsoft::QuickGelu", {}),
|
||||
("com.microsoft::QuickGelu", {"alpha": 1.0}),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
|
||||
@pytest.mark.parametrize("input_shape", [[1024, 4], [2, 3, 3, 3], [2049, 1]])
|
||||
def test_unary_elementwise_op(op_type, onnx_dtype, input_shape):
|
||||
def _create_model(op_type, onnx_dtype):
|
||||
def test_unary_elementwise_op(op, onnx_dtype, input_shape):
|
||||
def _create_model(op_type, onnx_dtype, **kwargs):
|
||||
domain = ""
|
||||
if "::" in op_type:
|
||||
domain, op_type = op_type.split("::")
|
||||
graph = helper.make_graph(
|
||||
[helper.make_node(op_type, ["X"], ["Y"], name="test")],
|
||||
[helper.make_node(op_type, ["X"], ["Y"], name="test", domain=domain, **kwargs)],
|
||||
"test",
|
||||
[helper.make_tensor_value_info("X", onnx_dtype, None)],
|
||||
[helper.make_tensor_value_info("Y", onnx_dtype, None)],
|
||||
|
|
@ -319,7 +374,7 @@ def test_unary_elementwise_op(op_type, onnx_dtype, input_shape):
|
|||
def _gen_inputs(dtype):
|
||||
return [torch.rand(*input_shape, dtype=dtype, device=DEVICE)]
|
||||
|
||||
_run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs)
|
||||
_run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])
|
||||
|
|
|
|||
Loading…
Reference in a new issue