From c47f446f25cf4dc97931f9e3326cf0d3545ed02b Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 18 Apr 2024 17:15:11 +0800 Subject: [PATCH] Support BFloat16 for Triton Codegen (#20353) Previous implementation used numpy array and numpy data_type to store constant value and data type, which is not support BFloat16 natively. This PR is to switch to use torch tensor which supports BFloat16. --- .../python/training/ort_triton/_codegen.py | 11 ++++---- .../python/training/ort_triton/_decompose.py | 7 +++-- .../python/training/ort_triton/_ir.py | 25 +++++++++-------- .../python/training/ort_triton/_lowering.py | 10 +++---- .../training/ort_triton/_sorted_graph.py | 11 ++++---- .../python/training/ort_triton/_utils.py | 21 ++++++++++----- .../training/ort_triton/kernel/_slice_scel.py | 8 +++--- .../orttraining_test_ortmodule_triton.py | 27 ------------------- 8 files changed, 48 insertions(+), 72 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index 9c7214f467..9a447d8019 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -14,7 +14,6 @@ Mostly, Nodes are classified into two categories: from typing import Tuple -import numpy as np import sympy import torch from sympy.codegen.rewriting import create_expand_pow_optimization @@ -316,14 +315,14 @@ class TritonCodegen(NodeVisitor): op_type = "Sqrt" if op_type == "Cast": - from_dtype = node.inputs[0].dtype.type - to_dtype = node.outputs[0].dtype.type + from_dtype = node.inputs[0].dtype + to_dtype = node.outputs[0].dtype if from_dtype == to_dtype or is_number(kwargs["i0"]): op_type = "Identity" - elif to_dtype == np.bool_: + elif to_dtype == torch.bool: op_type = "CastBool" else: - kwargs["dtype"] = to_dtype.__name__ + kwargs["dtype"] = str(to_dtype)[6:] # Remove "torch." prefix. if op_type == "QuickGelu" or op_type == "QuickGeluGrad": kwargs["alpha"] = str(node.attributes.get("alpha", 1.702)) @@ -473,7 +472,7 @@ class TritonCodegen(NodeVisitor): code_buffer += "\n" # Allocate output tensor. for output in kernel_node.outputs: - torch_dtype = torch.from_numpy(np.zeros(1, dtype=output.dtype)).dtype + torch_dtype = output.dtype # Workaround for DLPack which doesn't support bool. if torch_dtype == torch.bool: torch_dtype = torch.uint8 diff --git a/orttraining/orttraining/python/training/ort_triton/_decompose.py b/orttraining/orttraining/python/training/ort_triton/_decompose.py index ffd20b09b4..c1ded3975d 100644 --- a/orttraining/orttraining/python/training/ort_triton/_decompose.py +++ b/orttraining/orttraining/python/training/ort_triton/_decompose.py @@ -10,11 +10,10 @@ Decompose a complicated op into a series of simple ops. from typing import List -import numpy as np import sympy from onnx import GraphProto, NodeProto, TensorProto, helper -from ._utils import get_attribute, get_reduce_info, to_numpy_type +from ._utils import get_attribute, get_reduce_info def _is_half_dtype(dtype: int): @@ -132,7 +131,7 @@ class DecomposeDispatch: if axis < 0: axis += rank axes = list(range(axis, rank)) - epsilon_tensor = helper.make_tensor(name="epsilon_const", data_type=xdtype, dims=(1,), vals=np.array([epsilon])) + epsilon_tensor = helper.make_tensor(name="epsilon_const", data_type=xdtype, dims=(1,), vals=[epsilon]) const_node, const_out = self._new_node(node_name, "Constant", [], value=epsilon_tensor) reducemean_node, reducemean_out = self._new_node(node_name, "ReduceMean", [x], outputs=[mean], axes=axes) sub_node, sub_out = self._new_node(node_name, "Sub", [x, reducemean_out]) @@ -371,7 +370,7 @@ class DecomposeDispatch: name=f"{node_name}_denominator", dims=(), data_type=dtype, - vals=np.array([denominator], dtype=to_numpy_type(dtype)), + vals=[denominator], ) denominator_node, denominator_out = self._new_node(node_name, "Constant", [], value=denominator_tensor) div_node, _ = self._new_node(node_name, "Div", [sum_out, denominator_out], outputs=[y]) diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index a963d30a9e..23abb082c2 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -7,12 +7,12 @@ from abc import abstractmethod from collections import defaultdict from typing import Any, Dict, List, Optional, Set, Tuple -import numpy as np import sympy +import torch from ._common import AutotuneConfigs, CodeBuffer, CodegenContext, NodeVisitor, TensorInfo from ._sympy_utils import parse_shape -from ._utils import gen_unique_name, gen_variable_name, sort_reduce_axes, to_numpy_type +from ._utils import gen_unique_name, gen_variable_name, sort_reduce_axes, to_torch_dtype class TensorArg: @@ -22,15 +22,15 @@ class TensorArg: If it's constant (initializer or constant node), it also contains the data in numpy array. """ - def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[np.ndarray] = None): + def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[torch.Tensor] = None): self._name: str = name - self._data: Optional[np.ndarray] = data + self._data: Optional[torch.Tensor] = data if data is not None: - self._dtype: np.dtype = data.dtype + self._dtype: torch.dtype = data.dtype self._shape: List[sympy.Expr] = parse_shape(list(data.shape)) else: assert tensor_info is not None - self._dtype: np.dtype = to_numpy_type(tensor_info.dtype) + self._dtype: torch.dtype = to_torch_dtype(tensor_info.dtype) self._shape: List[sympy.Expr] = tensor_info.shape self.cross_kernels: bool = False @@ -39,7 +39,7 @@ class TensorArg: return self._name @property - def dtype(self) -> np.dtype: + def dtype(self) -> torch.dtype: return self._dtype @property @@ -47,7 +47,7 @@ class TensorArg: return self._shape @property - def data(self) -> Optional[np.ndarray]: + def data(self) -> Optional[torch.Tensor]: return self._data @@ -328,11 +328,10 @@ class KernelNode(IRNode): self.var_map[name] = gen_variable_name(name, "c", existing_names) if tensor_arg.data is not None: value = tensor_arg.data - if value is not None: - assert value.size == 1, f"unsupported constant array {value}" - variable_name = self.var_map[name] - assert variable_name not in self.var_map - self.var_map[variable_name] = str(np.array(value.item(), value.dtype)) + assert value.numel() == 1, f"unsupported constant {value} which has more than one element." + variable_name = self.var_map[name] + assert variable_name not in self.var_map + self.var_map[variable_name] = str(value.item()) class ElementwiseKernelNode(KernelNode): diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 4b580a0cc8..7253c7935a 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -28,7 +28,7 @@ from ._ir import ( ) from ._op_config import is_reduction_node from ._sorted_graph import SortedGraph -from ._utils import get_reduce_info, sort_reduce_axes, to_numpy_array +from ._utils import get_reduce_info, sort_reduce_axes, to_torch_tensor class NodeGroup: @@ -245,10 +245,10 @@ class GraphLowering: self._module_outputs = [TensorArg(output.name, self._node_arg_infos[output.name]) for output in graph.output] self._module_output_names = set(arg.name for arg in self._module_outputs) for initializer in graph.initializer: - data = to_numpy_array(initializer) + data = to_torch_tensor(initializer) self._module_constants.append(TensorArg(initializer.name, data=data)) for const_node in self._sorted_graph.const_nodes: - data = to_numpy_array(const_node) + data = to_torch_tensor(const_node) self._module_constants.append(TensorArg(const_node.output[0], data=data)) self._module_constant_names = set(arg.name for arg in self._module_constants) self._tensor_args = dict( @@ -415,7 +415,7 @@ class GraphLowering: for idx in range(cur, nxt): for input in sub_nodes[idx].inputs: if input.name in kernel_node.constants or input.name in input_names: - if (input.data is not None and input.data.size == 1) or input.name in load_cache: + if (input.data is not None and input.data.numel() == 1) or input.name in load_cache: continue load_nodes.append(IONode(input, kernel_node.offset_calc, True)) load_cache.add(input.name) @@ -432,7 +432,7 @@ class GraphLowering: for reduce_node in sub_nodes[nxt].reduce_nodes: input = reduce_node.inputs[0] if input.name in kernel_node.constants or input.name in input_names: - if (input.data is not None and input.data.size == 1) or input.name in load_cache: + if (input.data is not None and input.data.numel() == 1) or input.name in load_cache: continue load_nodes.append(IONode(input, kernel_node.offset_calc, True)) load_cache.add(input.name) diff --git a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py index 32e54d0868..d67a1c1665 100644 --- a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py +++ b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py @@ -7,7 +7,6 @@ import copy import itertools from typing import Dict, List, Set -import numpy as np import onnx import sympy from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper @@ -16,7 +15,7 @@ from ._common import SymbolicDSU, TensorInfo, TypeAndShapeInfer from ._decompose import DecomposeDispatch from ._op_config import is_elementwise_node from ._sympy_utils import parse_shape -from ._utils import get_attribute, to_numpy_array, topological_sort +from ._utils import get_attribute, to_torch_tensor, topological_sort class SortedGraph: @@ -58,7 +57,7 @@ class SortedGraph: for initializer in self._graph.initializer: self._node_arg_infos[initializer.name] = TensorInfo( initializer.data_type, - parse_shape(list(to_numpy_array(initializer).shape)), + parse_shape(list(initializer.dims)), ) # Decompose complex operators. @@ -97,13 +96,13 @@ class SortedGraph: constants = [] for idx, initializer in enumerate(self._sorted_initializers): - data_str = np.array2string(to_numpy_array(initializer), separator=",").replace("\n", "").replace(" ", "") + data_str = str(to_torch_tensor(initializer).tolist()).replace("\n", "").replace(" ", "") constants.append(f"({initializer.data_type},{data_str})") name_map[initializer.name] = f"c{idx}" for idx, node in enumerate(self._const_nodes): value_attr = get_attribute(node, "value") - data_str = np.array2string(to_numpy_array(value_attr), separator=",").replace("\n", "").replace(" ", "") + data_str = str(to_torch_tensor(value_attr).tolist()).replace("\n", "").replace(" ", "") constants.append(f"({value_attr.data_type},{data_str})") name_map[node.output[0]] = f"c{idx + len(self._sorted_initializers)}" constants_str = ",".join(constants) @@ -181,7 +180,7 @@ class SortedGraph: value_attr = get_attribute(node, "value") self._node_arg_infos[node.output[0]] = TensorInfo( value_attr.data_type, - parse_shape(list(to_numpy_array(value_attr).shape)), + parse_shape(list(value_attr.dims)), ) else: input_infos = [] diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index 877eacc0b7..e39a668bd0 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -9,8 +9,8 @@ from collections import defaultdict from typing import Any, List, Tuple import numpy as np +import torch from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper -from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE def gen_unique_name(prefix: str) -> str: @@ -73,17 +73,24 @@ def get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) -> return default_value -# Convert Constant node or TensorProto to numpy array. -def to_numpy_array(node: Any) -> np.ndarray: +# Convert Constant node or TensorProto to torch.Tensor. +def to_torch_tensor(node: Any) -> torch.Tensor: tensor = node if isinstance(node, NodeProto): tensor = get_attribute(node, "value") assert isinstance(tensor, TensorProto) - return numpy_helper.to_array(tensor) + torch_tensor = torch.from_numpy(numpy_helper.to_array(tensor)) + # numpy does not support bfloat16 and create a float32 tensor instead. + if tensor.data_type == TensorProto.BFLOAT16: + torch_tensor = torch_tensor.to(torch.bfloat16) + return torch_tensor -def to_numpy_type(tensor_type: TensorProto.DataType) -> np.dtype: - return TENSOR_TYPE_TO_NP_TYPE[tensor_type] if not isinstance(tensor_type, np.dtype) else tensor_type +def to_torch_dtype(tensor_type: TensorProto.DataType) -> torch.dtype: + # Native numpy does not support bfloat16. + if tensor_type == TensorProto.BFLOAT16: + return torch.bfloat16 + return torch.from_numpy(np.zeros(1, dtype=helper.tensor_dtype_to_np_dtype(tensor_type))).dtype # Generate a unique variable name based on the node arg name. @@ -133,7 +140,7 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl axes_initializer = initializer break assert axes_initializer is not None - axes = to_numpy_array(axes_initializer).tolist() + axes = to_torch_tensor(axes_initializer).tolist() if axes is None: axes = list(range(input_rank)) if noop_with_empty_axes == 0 else [] axes = sort_reduce_axes(axes, input_rank, check_contiguous=False) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py index fb7ddc6890..b2d9ca9c7f 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py @@ -13,7 +13,7 @@ from onnx import TensorProto, helper from onnxruntime.training.ortmodule import register_graph_optimizer -from .._utils import get_attribute, to_numpy_array +from .._utils import get_attribute, to_torch_tensor @triton.jit @@ -212,7 +212,7 @@ def _get_constant(graph, arg): initializer = init if initializer is None: return None - return to_numpy_array(initializer) + return to_torch_tensor(initializer).tolist() def _check_slice(graph, node, start, end, axis, step): @@ -224,9 +224,9 @@ def _check_slice(graph, node, start, end, axis, step): axis += rank for idx, value in enumerate([start, end, axis, step]): constant = _get_constant(graph, node.input[idx + 1]) - if constant is None or constant.size != 1: + if constant is None or len(constant) != 1: return False - constant_value = constant.item() + constant_value = constant[0] if idx == 2 and constant_value < 0: constant_value += rank if constant_value != value: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 922f5c6965..0c381d70ca 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -12,7 +12,6 @@ import onnx import pytest import torch from onnx import TensorProto, helper -from packaging.version import Version from torch._C import _from_dlpack from torch.utils.dlpack import to_dlpack @@ -843,32 +842,6 @@ def test_slice_scel_module(dtype, has_sum): _run_module_test(NeuralNetSliceScel, dtype, _gen_inputs, 2) -@pytest.mark.skipif( - Version(torch.__version__) < Version("2.1"), reason="PyTorch has scaled_dot_product_attention since 2.1." -) -def test_scaled_dot_product_attention_module(): - class NeuralNetScaledDotProductAttention(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) - self.linear2 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) - self.linear3 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) - - def forward(self, q, k, v): - return torch.nn.functional.scaled_dot_product_attention( - self.linear1(q), self.linear2(k), self.linear3(v) - ).to(torch.float16) - - def _gen_inputs(dtype): - return [ - (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), - (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), - (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), - ] - - _run_module_test(NeuralNetScaledDotProductAttention, torch.float16, _gen_inputs, 3) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("input_shapes", [([128, 64], [64, 64]), ([16, 64, 128], [16, 128, 64])]) def test_matmul_tunable_op(dtype, input_shapes):