diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index c071f01f87..8e21013da2 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -263,10 +263,20 @@ class TritonCodegen(NodeVisitor): "Rsqrt": "{indent}{o0} = 1.0 / tl.sqrt({i0})\n", "Cast": "{indent}{o0} = {i0}.to(tl.{dtype})\n", "CastBool": "{indent}{o0} = {i0} != 0\n", - "Erf": "{indent}{o0} = tl.libdevice.erf({i0})\n", - "Gelu": "{indent}{o0} = (tl.libdevice.erf({i0} / 1.41421356237) + 1.0) * 0.5\n", + "Erf": "{indent}{o0} = tl.erf({i0})\n", + "Gelu": "{indent}{o0} = {i0} * 0.5 * (tl.math.erf({i0} * 0.70710678118654752440) + 1.0)\n", + "QuickGelu": "{indent}{o0} = {i0} * tl.sigmoid({i0} * {alpha})\n", + "GeluGrad": ( + "{indent}{o0} = {i0} * (0.5 * (1.0 + tl.math.erf(0.70710678118654752440 * {i1})) + " + "{i1} * 1.12837916709551257390 * 0.70710678118654752440 * 0.5 * tl.exp(-0.5 * {i1} * {i1}))\n" + ), + "QuickGeluGrad": ( + "{indent}tmp_v = {i1} * {alpha}\n" + "{indent}tmp_sigmoid = tl.sigmoid(tmp_v)\n" + "{indent}{o0} = {i0} * tmp_sigmoid * (1.0 + tmp_v * (1.0 - tmp_sigmoid))\n" + ), "Exp": "{indent}{o0} = tl.exp({i0})\n", - "Tanh": "{indent}{o0} = tl.libdevice.tanh({i0})\n", + "Tanh": "{indent}{o0} = tl.math.tanh({i0})\n", "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", @@ -303,6 +313,9 @@ class TritonCodegen(NodeVisitor): else: kwargs["dtype"] = to_dtype.__name__ + if op_type == "QuickGelu" or op_type == "QuickGeluGrad": + kwargs["alpha"] = str(node.attributes.get("alpha", 1.702)) + if op_type == "Sum": output_var = kwargs["o0"] formula = " + ".join([kwargs[f"i{idx}"] for idx in range(len(node.inputs))]) diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index 6554020242..82ac82cfa2 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -131,6 +131,10 @@ class TypeAndShapeInfer: "ReduceMax": _infer_reduction, "ReduceMin": _infer_reduction, "Sum": _infer_elementwise, + "Gelu": _infer_unary, + "QuickGelu": _infer_unary, + "GeluGrad": _infer_elementwise, + "QuickGeluGrad": _infer_elementwise, } @classmethod diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 8aa5c1b131..f7d3b31eac 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np import sympy @@ -184,14 +184,25 @@ class ComputeNode(IRNode): Each operator is represented as a ComputeNode. """ - def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg]): + def __init__( + self, + op_type: str, + inputs: List[TensorArg], + outputs: List[TensorArg], + attributes: Dict[str, Any] = {}, # noqa: B006 + ): super().__init__(inputs, outputs) self._op_type: str = op_type + self._attributes: Dict[str, Any] = attributes @property def op_type(self): return self._op_type + @property + def attributes(self): + return self._attributes + class ReduceNode(ComputeNode): def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator): diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 16db9ab000..5de60e6943 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -9,7 +9,7 @@ from collections import defaultdict from typing import Any, Dict, List, Set, Tuple import sympy -from onnx import NodeProto +from onnx import NodeProto, helper from ._common import AutotuneConfigs, TensorInfo from ._ir import ( @@ -378,7 +378,10 @@ class GraphLowering: return DropoutNode(inputs, outputs, offset_calc) if is_reduction_node(node): return ReduceNode(op_type, inputs, outputs, offset_calc) - return ComputeNode(op_type, inputs, outputs) + attributes = {} + for attr in node.attribute: + attributes[attr.name] = helper.get_attribute_value(attr) + return ComputeNode(op_type, inputs, outputs, attributes) def _analyze_kernel_io_list(self): cross_kernel_inputs = set() diff --git a/orttraining/orttraining/python/training/ort_triton/_op_config.py b/orttraining/orttraining/python/training/ort_triton/_op_config.py index f58d0e1847..7d9af00933 100644 --- a/orttraining/orttraining/python/training/ort_triton/_op_config.py +++ b/orttraining/orttraining/python/training/ort_triton/_op_config.py @@ -36,6 +36,10 @@ _ELEMENTWISE_OPS = { "DropoutGrad": {"domain": "com.microsoft", "versions": [1]}, "Identity": {"versions": [13], "is_no_op": True}, "Sum": {"versions": [13]}, + "Gelu": {"domain": "com.microsoft", "versions": [1]}, + "QuickGelu": {"domain": "com.microsoft", "versions": [1]}, + "GeluGrad": {"domain": "com.microsoft", "versions": [1]}, + "QuickGeluGrad": {"domain": "com.microsoft", "versions": [1]}, } _REDUCTION_OPS = { diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 318de843ef..d205e8f237 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -135,8 +135,31 @@ def _torch_layer_norm(input, weight, bias, **kwargs): return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias) +def _torch_gelu(input): + return torch.nn.functional.gelu(input) + + +def _torch_quick_gelu(input, **kwargs): + alpha = kwargs.get("alpha", 1.702) + return input * torch.sigmoid(input * alpha) + + +def _torch_gelu_grad(dy, x): + alpha = 0.70710678118654752440 + beta = 1.12837916709551257390 * 0.70710678118654752440 * 0.5 + cdf = 0.5 * (1 + torch.erf(x * alpha)) + pdf = beta * torch.exp(x * x * -0.5) + return dy * (cdf + x * pdf) + + +def _torch_quick_gelu_grad(dy, x, **kwargs): + alpha = kwargs.get("alpha", 1.702) + sigmoid = torch.sigmoid(x * alpha) + return dy * sigmoid * (1.0 + x * alpha * (1.0 - sigmoid)) + + class TorchFuncExecutor: - _INFER_FUNC_MAP = { # noqa: RUF012 + _TORCH_FUNC_MAP = { # noqa: RUF012 "Add": _torch_add, "Sub": _torch_sub, "Mul": _torch_mul, @@ -154,13 +177,17 @@ class TorchFuncExecutor: "ReduceMin": _torch_reduce_min, "Softmax": _torch_softmax, "LayerNormalization": _torch_layer_norm, + "Gelu": _torch_gelu, + "QuickGelu": _torch_quick_gelu, + "GeluGrad": _torch_gelu_grad, + "QuickGeluGrad": _torch_quick_gelu_grad, } @classmethod def run(cls, op_type, *torch_tensors, **kwargs): - if op_type not in cls._INFER_FUNC_MAP: + if op_type not in cls._TORCH_FUNC_MAP: raise NotImplementedError(f"Unsupported op type: {op_type}") - return cls._INFER_FUNC_MAP[op_type](*torch_tensors, **kwargs) + return cls._TORCH_FUNC_MAP[op_type](*torch_tensors, **kwargs) def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwargs): @@ -169,6 +196,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar pt_inputs = gen_inputs_func(_onnx_dtype_to_torch_dtype(onnx_dtype)) ort_inputs = copy.deepcopy(pt_inputs) ort_inputs = [tensor.to(torch.uint8) if tensor.dtype == torch.bool else tensor for tensor in ort_inputs] + if "::" in op_type: + _, op_type = op_type.split("::") pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs) model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString() ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) @@ -260,13 +289,27 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co del os.environ["ORTMODULE_TUNING_RESULTS_PATH"] -@pytest.mark.parametrize("op_type", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize( + "op", + [ + ("Add", {}), + ("Sub", {}), + ("Mul", {}), + ("Div", {}), + ("com.microsoft::GeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shapes", [([1024, 2], [1024, 2]), ([2, 3, 3, 3], [3, 1, 3]), ([2049], [1])]) -def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes): - def _create_model(op_type, onnx_dtype): +def test_binary_elementwise_op(op, onnx_dtype, input_shapes): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test")], + [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test", domain=domain, **kwargs)], "test", [ helper.make_tensor_value_info("X", onnx_dtype, None), @@ -282,7 +325,7 @@ def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes): torch.randn(*input_shapes[1], dtype=dtype, device=DEVICE), ] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -303,13 +346,25 @@ def test_sum_op(onnx_dtype, input_shapes): _run_op_test("Sum", onnx_dtype, _create_model, _gen_inputs) -@pytest.mark.parametrize("op_type", ["Sqrt", "Exp"]) +@pytest.mark.parametrize( + "op", + [ + ("Sqrt", {}), + ("Exp", {}), + ("com.microsoft::Gelu", {}), + ("com.microsoft::QuickGelu", {}), + ("com.microsoft::QuickGelu", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shape", [[1024, 4], [2, 3, 3, 3], [2049, 1]]) -def test_unary_elementwise_op(op_type, onnx_dtype, input_shape): - def _create_model(op_type, onnx_dtype): +def test_unary_elementwise_op(op, onnx_dtype, input_shape): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X"], ["Y"], name="test")], + [helper.make_node(op_type, ["X"], ["Y"], name="test", domain=domain, **kwargs)], "test", [helper.make_tensor_value_info("X", onnx_dtype, None)], [helper.make_tensor_value_info("Y", onnx_dtype, None)], @@ -319,7 +374,7 @@ def test_unary_elementwise_op(op_type, onnx_dtype, input_shape): def _gen_inputs(dtype): return [torch.rand(*input_shape, dtype=dtype, device=DEVICE)] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16])