From e6aa0fa17464e786b5e0a9811de7a04fe86329bf Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 27 Sep 2023 19:57:39 +0800 Subject: [PATCH] Add Gelu Related Ops to Triton Codegen (#17713) Add Gelu/QuickGelu/GeluGrad/QuickGeluGrad support to Triton Codegen so that it can be fused with some other connected supported Ops. For example, in llama2, it can be fused with Mul so we will have extra 1-2% perf gain. --- .../python/training/ort_triton/_codegen.py | 19 ++++- .../python/training/ort_triton/_common.py | 4 + .../python/training/ort_triton/_ir.py | 15 +++- .../python/training/ort_triton/_lowering.py | 7 +- .../python/training/ort_triton/_op_config.py | 4 + .../orttraining_test_ortmodule_triton.py | 81 ++++++++++++++++--- 6 files changed, 110 insertions(+), 20 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index c071f01f87..8e21013da2 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -263,10 +263,20 @@ class TritonCodegen(NodeVisitor): "Rsqrt": "{indent}{o0} = 1.0 / tl.sqrt({i0})\n", "Cast": "{indent}{o0} = {i0}.to(tl.{dtype})\n", "CastBool": "{indent}{o0} = {i0} != 0\n", - "Erf": "{indent}{o0} = tl.libdevice.erf({i0})\n", - "Gelu": "{indent}{o0} = (tl.libdevice.erf({i0} / 1.41421356237) + 1.0) * 0.5\n", + "Erf": "{indent}{o0} = tl.erf({i0})\n", + "Gelu": "{indent}{o0} = {i0} * 0.5 * (tl.math.erf({i0} * 0.70710678118654752440) + 1.0)\n", + "QuickGelu": "{indent}{o0} = {i0} * tl.sigmoid({i0} * {alpha})\n", + "GeluGrad": ( + "{indent}{o0} = {i0} * (0.5 * (1.0 + tl.math.erf(0.70710678118654752440 * {i1})) + " + "{i1} * 1.12837916709551257390 * 0.70710678118654752440 * 0.5 * tl.exp(-0.5 * {i1} * {i1}))\n" + ), + "QuickGeluGrad": ( + "{indent}tmp_v = {i1} * {alpha}\n" + "{indent}tmp_sigmoid = tl.sigmoid(tmp_v)\n" + "{indent}{o0} = {i0} * tmp_sigmoid * (1.0 + tmp_v * (1.0 - tmp_sigmoid))\n" + ), "Exp": "{indent}{o0} = tl.exp({i0})\n", - "Tanh": "{indent}{o0} = tl.libdevice.tanh({i0})\n", + "Tanh": "{indent}{o0} = tl.math.tanh({i0})\n", "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", @@ -303,6 +313,9 @@ class TritonCodegen(NodeVisitor): else: kwargs["dtype"] = to_dtype.__name__ + if op_type == "QuickGelu" or op_type == "QuickGeluGrad": + kwargs["alpha"] = str(node.attributes.get("alpha", 1.702)) + if op_type == "Sum": output_var = kwargs["o0"] formula = " + ".join([kwargs[f"i{idx}"] for idx in range(len(node.inputs))]) diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index 6554020242..82ac82cfa2 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -131,6 +131,10 @@ class TypeAndShapeInfer: "ReduceMax": _infer_reduction, "ReduceMin": _infer_reduction, "Sum": _infer_elementwise, + "Gelu": _infer_unary, + "QuickGelu": _infer_unary, + "GeluGrad": _infer_elementwise, + "QuickGeluGrad": _infer_elementwise, } @classmethod diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 8aa5c1b131..f7d3b31eac 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np import sympy @@ -184,14 +184,25 @@ class ComputeNode(IRNode): Each operator is represented as a ComputeNode. """ - def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg]): + def __init__( + self, + op_type: str, + inputs: List[TensorArg], + outputs: List[TensorArg], + attributes: Dict[str, Any] = {}, # noqa: B006 + ): super().__init__(inputs, outputs) self._op_type: str = op_type + self._attributes: Dict[str, Any] = attributes @property def op_type(self): return self._op_type + @property + def attributes(self): + return self._attributes + class ReduceNode(ComputeNode): def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator): diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 16db9ab000..5de60e6943 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -9,7 +9,7 @@ from collections import defaultdict from typing import Any, Dict, List, Set, Tuple import sympy -from onnx import NodeProto +from onnx import NodeProto, helper from ._common import AutotuneConfigs, TensorInfo from ._ir import ( @@ -378,7 +378,10 @@ class GraphLowering: return DropoutNode(inputs, outputs, offset_calc) if is_reduction_node(node): return ReduceNode(op_type, inputs, outputs, offset_calc) - return ComputeNode(op_type, inputs, outputs) + attributes = {} + for attr in node.attribute: + attributes[attr.name] = helper.get_attribute_value(attr) + return ComputeNode(op_type, inputs, outputs, attributes) def _analyze_kernel_io_list(self): cross_kernel_inputs = set() diff --git a/orttraining/orttraining/python/training/ort_triton/_op_config.py b/orttraining/orttraining/python/training/ort_triton/_op_config.py index f58d0e1847..7d9af00933 100644 --- a/orttraining/orttraining/python/training/ort_triton/_op_config.py +++ b/orttraining/orttraining/python/training/ort_triton/_op_config.py @@ -36,6 +36,10 @@ _ELEMENTWISE_OPS = { "DropoutGrad": {"domain": "com.microsoft", "versions": [1]}, "Identity": {"versions": [13], "is_no_op": True}, "Sum": {"versions": [13]}, + "Gelu": {"domain": "com.microsoft", "versions": [1]}, + "QuickGelu": {"domain": "com.microsoft", "versions": [1]}, + "GeluGrad": {"domain": "com.microsoft", "versions": [1]}, + "QuickGeluGrad": {"domain": "com.microsoft", "versions": [1]}, } _REDUCTION_OPS = { diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 318de843ef..d205e8f237 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -135,8 +135,31 @@ def _torch_layer_norm(input, weight, bias, **kwargs): return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias) +def _torch_gelu(input): + return torch.nn.functional.gelu(input) + + +def _torch_quick_gelu(input, **kwargs): + alpha = kwargs.get("alpha", 1.702) + return input * torch.sigmoid(input * alpha) + + +def _torch_gelu_grad(dy, x): + alpha = 0.70710678118654752440 + beta = 1.12837916709551257390 * 0.70710678118654752440 * 0.5 + cdf = 0.5 * (1 + torch.erf(x * alpha)) + pdf = beta * torch.exp(x * x * -0.5) + return dy * (cdf + x * pdf) + + +def _torch_quick_gelu_grad(dy, x, **kwargs): + alpha = kwargs.get("alpha", 1.702) + sigmoid = torch.sigmoid(x * alpha) + return dy * sigmoid * (1.0 + x * alpha * (1.0 - sigmoid)) + + class TorchFuncExecutor: - _INFER_FUNC_MAP = { # noqa: RUF012 + _TORCH_FUNC_MAP = { # noqa: RUF012 "Add": _torch_add, "Sub": _torch_sub, "Mul": _torch_mul, @@ -154,13 +177,17 @@ class TorchFuncExecutor: "ReduceMin": _torch_reduce_min, "Softmax": _torch_softmax, "LayerNormalization": _torch_layer_norm, + "Gelu": _torch_gelu, + "QuickGelu": _torch_quick_gelu, + "GeluGrad": _torch_gelu_grad, + "QuickGeluGrad": _torch_quick_gelu_grad, } @classmethod def run(cls, op_type, *torch_tensors, **kwargs): - if op_type not in cls._INFER_FUNC_MAP: + if op_type not in cls._TORCH_FUNC_MAP: raise NotImplementedError(f"Unsupported op type: {op_type}") - return cls._INFER_FUNC_MAP[op_type](*torch_tensors, **kwargs) + return cls._TORCH_FUNC_MAP[op_type](*torch_tensors, **kwargs) def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwargs): @@ -169,6 +196,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar pt_inputs = gen_inputs_func(_onnx_dtype_to_torch_dtype(onnx_dtype)) ort_inputs = copy.deepcopy(pt_inputs) ort_inputs = [tensor.to(torch.uint8) if tensor.dtype == torch.bool else tensor for tensor in ort_inputs] + if "::" in op_type: + _, op_type = op_type.split("::") pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs) model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString() ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) @@ -260,13 +289,27 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co del os.environ["ORTMODULE_TUNING_RESULTS_PATH"] -@pytest.mark.parametrize("op_type", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize( + "op", + [ + ("Add", {}), + ("Sub", {}), + ("Mul", {}), + ("Div", {}), + ("com.microsoft::GeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shapes", [([1024, 2], [1024, 2]), ([2, 3, 3, 3], [3, 1, 3]), ([2049], [1])]) -def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes): - def _create_model(op_type, onnx_dtype): +def test_binary_elementwise_op(op, onnx_dtype, input_shapes): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test")], + [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test", domain=domain, **kwargs)], "test", [ helper.make_tensor_value_info("X", onnx_dtype, None), @@ -282,7 +325,7 @@ def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes): torch.randn(*input_shapes[1], dtype=dtype, device=DEVICE), ] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -303,13 +346,25 @@ def test_sum_op(onnx_dtype, input_shapes): _run_op_test("Sum", onnx_dtype, _create_model, _gen_inputs) -@pytest.mark.parametrize("op_type", ["Sqrt", "Exp"]) +@pytest.mark.parametrize( + "op", + [ + ("Sqrt", {}), + ("Exp", {}), + ("com.microsoft::Gelu", {}), + ("com.microsoft::QuickGelu", {}), + ("com.microsoft::QuickGelu", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shape", [[1024, 4], [2, 3, 3, 3], [2049, 1]]) -def test_unary_elementwise_op(op_type, onnx_dtype, input_shape): - def _create_model(op_type, onnx_dtype): +def test_unary_elementwise_op(op, onnx_dtype, input_shape): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X"], ["Y"], name="test")], + [helper.make_node(op_type, ["X"], ["Y"], name="test", domain=domain, **kwargs)], "test", [helper.make_tensor_value_info("X", onnx_dtype, None)], [helper.make_tensor_value_info("Y", onnx_dtype, None)], @@ -319,7 +374,7 @@ def test_unary_elementwise_op(op_type, onnx_dtype, input_shape): def _gen_inputs(dtype): return [torch.rand(*input_shape, dtype=dtype, device=DEVICE)] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])