Support BFloat16 for Triton Codegen (#20353)

The previous implementation used numpy arrays and numpy dtypes to store
constant values and data types, but numpy does not support BFloat16 natively.
This PR switches to torch tensors and torch dtypes, which support BFloat16.
This commit is contained in:
Vincent Wang 2024-04-18 17:15:11 +08:00 committed by GitHub
parent da86f6f408
commit c47f446f25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 48 additions and 72 deletions

View file

@ -14,7 +14,6 @@ Mostly, Nodes are classified into two categories:
from typing import Tuple
import numpy as np
import sympy
import torch
from sympy.codegen.rewriting import create_expand_pow_optimization
@ -316,14 +315,14 @@ class TritonCodegen(NodeVisitor):
op_type = "Sqrt"
if op_type == "Cast":
from_dtype = node.inputs[0].dtype.type
to_dtype = node.outputs[0].dtype.type
from_dtype = node.inputs[0].dtype
to_dtype = node.outputs[0].dtype
if from_dtype == to_dtype or is_number(kwargs["i0"]):
op_type = "Identity"
elif to_dtype == np.bool_:
elif to_dtype == torch.bool:
op_type = "CastBool"
else:
kwargs["dtype"] = to_dtype.__name__
kwargs["dtype"] = str(to_dtype)[6:] # Remove "torch." prefix.
if op_type == "QuickGelu" or op_type == "QuickGeluGrad":
kwargs["alpha"] = str(node.attributes.get("alpha", 1.702))
@ -473,7 +472,7 @@ class TritonCodegen(NodeVisitor):
code_buffer += "\n"
# Allocate output tensor.
for output in kernel_node.outputs:
torch_dtype = torch.from_numpy(np.zeros(1, dtype=output.dtype)).dtype
torch_dtype = output.dtype
# Workaround for DLPack which doesn't support bool.
if torch_dtype == torch.bool:
torch_dtype = torch.uint8

View file

@ -10,11 +10,10 @@ Decompose a complicated op into a series of simple ops.
from typing import List
import numpy as np
import sympy
from onnx import GraphProto, NodeProto, TensorProto, helper
from ._utils import get_attribute, get_reduce_info, to_numpy_type
from ._utils import get_attribute, get_reduce_info
def _is_half_dtype(dtype: int):
@ -132,7 +131,7 @@ class DecomposeDispatch:
if axis < 0:
axis += rank
axes = list(range(axis, rank))
epsilon_tensor = helper.make_tensor(name="epsilon_const", data_type=xdtype, dims=(1,), vals=np.array([epsilon]))
epsilon_tensor = helper.make_tensor(name="epsilon_const", data_type=xdtype, dims=(1,), vals=[epsilon])
const_node, const_out = self._new_node(node_name, "Constant", [], value=epsilon_tensor)
reducemean_node, reducemean_out = self._new_node(node_name, "ReduceMean", [x], outputs=[mean], axes=axes)
sub_node, sub_out = self._new_node(node_name, "Sub", [x, reducemean_out])
@ -371,7 +370,7 @@ class DecomposeDispatch:
name=f"{node_name}_denominator",
dims=(),
data_type=dtype,
vals=np.array([denominator], dtype=to_numpy_type(dtype)),
vals=[denominator],
)
denominator_node, denominator_out = self._new_node(node_name, "Constant", [], value=denominator_tensor)
div_node, _ = self._new_node(node_name, "Div", [sum_out, denominator_out], outputs=[y])

View file

@ -7,12 +7,12 @@ from abc import abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np
import sympy
import torch
from ._common import AutotuneConfigs, CodeBuffer, CodegenContext, NodeVisitor, TensorInfo
from ._sympy_utils import parse_shape
from ._utils import gen_unique_name, gen_variable_name, sort_reduce_axes, to_numpy_type
from ._utils import gen_unique_name, gen_variable_name, sort_reduce_axes, to_torch_dtype
class TensorArg:
@ -22,15 +22,15 @@ class TensorArg:
If it's constant (initializer or constant node), it also contains the data in numpy array.
"""
def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[np.ndarray] = None):
def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[torch.Tensor] = None):
self._name: str = name
self._data: Optional[np.ndarray] = data
self._data: Optional[torch.Tensor] = data
if data is not None:
self._dtype: np.dtype = data.dtype
self._dtype: torch.dtype = data.dtype
self._shape: List[sympy.Expr] = parse_shape(list(data.shape))
else:
assert tensor_info is not None
self._dtype: np.dtype = to_numpy_type(tensor_info.dtype)
self._dtype: torch.dtype = to_torch_dtype(tensor_info.dtype)
self._shape: List[sympy.Expr] = tensor_info.shape
self.cross_kernels: bool = False
@ -39,7 +39,7 @@ class TensorArg:
return self._name
@property
def dtype(self) -> np.dtype:
def dtype(self) -> torch.dtype:
return self._dtype
@property
@ -47,7 +47,7 @@ class TensorArg:
return self._shape
@property
def data(self) -> Optional[np.ndarray]:
def data(self) -> Optional[torch.Tensor]:
return self._data
@ -328,11 +328,10 @@ class KernelNode(IRNode):
self.var_map[name] = gen_variable_name(name, "c", existing_names)
if tensor_arg.data is not None:
value = tensor_arg.data
if value is not None:
assert value.size == 1, f"unsupported constant array {value}"
variable_name = self.var_map[name]
assert variable_name not in self.var_map
self.var_map[variable_name] = str(np.array(value.item(), value.dtype))
assert value.numel() == 1, f"unsupported constant {value} which has more than one element."
variable_name = self.var_map[name]
assert variable_name not in self.var_map
self.var_map[variable_name] = str(value.item())
class ElementwiseKernelNode(KernelNode):

View file

@ -28,7 +28,7 @@ from ._ir import (
)
from ._op_config import is_reduction_node
from ._sorted_graph import SortedGraph
from ._utils import get_reduce_info, sort_reduce_axes, to_numpy_array
from ._utils import get_reduce_info, sort_reduce_axes, to_torch_tensor
class NodeGroup:
@ -245,10 +245,10 @@ class GraphLowering:
self._module_outputs = [TensorArg(output.name, self._node_arg_infos[output.name]) for output in graph.output]
self._module_output_names = set(arg.name for arg in self._module_outputs)
for initializer in graph.initializer:
data = to_numpy_array(initializer)
data = to_torch_tensor(initializer)
self._module_constants.append(TensorArg(initializer.name, data=data))
for const_node in self._sorted_graph.const_nodes:
data = to_numpy_array(const_node)
data = to_torch_tensor(const_node)
self._module_constants.append(TensorArg(const_node.output[0], data=data))
self._module_constant_names = set(arg.name for arg in self._module_constants)
self._tensor_args = dict(
@ -415,7 +415,7 @@ class GraphLowering:
for idx in range(cur, nxt):
for input in sub_nodes[idx].inputs:
if input.name in kernel_node.constants or input.name in input_names:
if (input.data is not None and input.data.size == 1) or input.name in load_cache:
if (input.data is not None and input.data.numel() == 1) or input.name in load_cache:
continue
load_nodes.append(IONode(input, kernel_node.offset_calc, True))
load_cache.add(input.name)
@ -432,7 +432,7 @@ class GraphLowering:
for reduce_node in sub_nodes[nxt].reduce_nodes:
input = reduce_node.inputs[0]
if input.name in kernel_node.constants or input.name in input_names:
if (input.data is not None and input.data.size == 1) or input.name in load_cache:
if (input.data is not None and input.data.numel() == 1) or input.name in load_cache:
continue
load_nodes.append(IONode(input, kernel_node.offset_calc, True))
load_cache.add(input.name)

View file

@ -7,7 +7,6 @@ import copy
import itertools
from typing import Dict, List, Set
import numpy as np
import onnx
import sympy
from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper
@ -16,7 +15,7 @@ from ._common import SymbolicDSU, TensorInfo, TypeAndShapeInfer
from ._decompose import DecomposeDispatch
from ._op_config import is_elementwise_node
from ._sympy_utils import parse_shape
from ._utils import get_attribute, to_numpy_array, topological_sort
from ._utils import get_attribute, to_torch_tensor, topological_sort
class SortedGraph:
@ -58,7 +57,7 @@ class SortedGraph:
for initializer in self._graph.initializer:
self._node_arg_infos[initializer.name] = TensorInfo(
initializer.data_type,
parse_shape(list(to_numpy_array(initializer).shape)),
parse_shape(list(initializer.dims)),
)
# Decompose complex operators.
@ -97,13 +96,13 @@ class SortedGraph:
constants = []
for idx, initializer in enumerate(self._sorted_initializers):
data_str = np.array2string(to_numpy_array(initializer), separator=",").replace("\n", "").replace(" ", "")
data_str = str(to_torch_tensor(initializer).tolist()).replace("\n", "").replace(" ", "")
constants.append(f"({initializer.data_type},{data_str})")
name_map[initializer.name] = f"c{idx}"
for idx, node in enumerate(self._const_nodes):
value_attr = get_attribute(node, "value")
data_str = np.array2string(to_numpy_array(value_attr), separator=",").replace("\n", "").replace(" ", "")
data_str = str(to_torch_tensor(value_attr).tolist()).replace("\n", "").replace(" ", "")
constants.append(f"({value_attr.data_type},{data_str})")
name_map[node.output[0]] = f"c{idx + len(self._sorted_initializers)}"
constants_str = ",".join(constants)
@ -181,7 +180,7 @@ class SortedGraph:
value_attr = get_attribute(node, "value")
self._node_arg_infos[node.output[0]] = TensorInfo(
value_attr.data_type,
parse_shape(list(to_numpy_array(value_attr).shape)),
parse_shape(list(value_attr.dims)),
)
else:
input_infos = []

View file

@ -9,8 +9,8 @@ from collections import defaultdict
from typing import Any, List, Tuple
import numpy as np
import torch
from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
def gen_unique_name(prefix: str) -> str:
@ -73,17 +73,24 @@ def get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) ->
return default_value
# Convert Constant node or TensorProto to numpy array.
def to_numpy_array(node: Any) -> np.ndarray:
# Convert Constant node or TensorProto to torch.Tensor.
def to_torch_tensor(node: Any) -> torch.Tensor:
tensor = node
if isinstance(node, NodeProto):
tensor = get_attribute(node, "value")
assert isinstance(tensor, TensorProto)
return numpy_helper.to_array(tensor)
torch_tensor = torch.from_numpy(numpy_helper.to_array(tensor))
# numpy does not support bfloat16 and create a float32 tensor instead.
if tensor.data_type == TensorProto.BFLOAT16:
torch_tensor = torch_tensor.to(torch.bfloat16)
return torch_tensor
def to_numpy_type(tensor_type: TensorProto.DataType) -> np.dtype:
return TENSOR_TYPE_TO_NP_TYPE[tensor_type] if not isinstance(tensor_type, np.dtype) else tensor_type
def to_torch_dtype(tensor_type: TensorProto.DataType) -> torch.dtype:
# Native numpy does not support bfloat16.
if tensor_type == TensorProto.BFLOAT16:
return torch.bfloat16
return torch.from_numpy(np.zeros(1, dtype=helper.tensor_dtype_to_np_dtype(tensor_type))).dtype
# Generate a unique variable name based on the node arg name.
@ -133,7 +140,7 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl
axes_initializer = initializer
break
assert axes_initializer is not None
axes = to_numpy_array(axes_initializer).tolist()
axes = to_torch_tensor(axes_initializer).tolist()
if axes is None:
axes = list(range(input_rank)) if noop_with_empty_axes == 0 else []
axes = sort_reduce_axes(axes, input_rank, check_contiguous=False)

View file

@ -13,7 +13,7 @@ from onnx import TensorProto, helper
from onnxruntime.training.ortmodule import register_graph_optimizer
from .._utils import get_attribute, to_numpy_array
from .._utils import get_attribute, to_torch_tensor
@triton.jit
@ -212,7 +212,7 @@ def _get_constant(graph, arg):
initializer = init
if initializer is None:
return None
return to_numpy_array(initializer)
return to_torch_tensor(initializer).tolist()
def _check_slice(graph, node, start, end, axis, step):
@ -224,9 +224,9 @@ def _check_slice(graph, node, start, end, axis, step):
axis += rank
for idx, value in enumerate([start, end, axis, step]):
constant = _get_constant(graph, node.input[idx + 1])
if constant is None or constant.size != 1:
if constant is None or len(constant) != 1:
return False
constant_value = constant.item()
constant_value = constant[0]
if idx == 2 and constant_value < 0:
constant_value += rank
if constant_value != value:

View file

@ -12,7 +12,6 @@ import onnx
import pytest
import torch
from onnx import TensorProto, helper
from packaging.version import Version
from torch._C import _from_dlpack
from torch.utils.dlpack import to_dlpack
@ -843,32 +842,6 @@ def test_slice_scel_module(dtype, has_sum):
_run_module_test(NeuralNetSliceScel, dtype, _gen_inputs, 2)
@pytest.mark.skipif(
Version(torch.__version__) < Version("2.1"), reason="PyTorch has scaled_dot_product_attention since 2.1."
)
def test_scaled_dot_product_attention_module():
class NeuralNetScaledDotProductAttention(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
self.linear2 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
self.linear3 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
def forward(self, q, k, v):
return torch.nn.functional.scaled_dot_product_attention(
self.linear1(q), self.linear2(k), self.linear3(v)
).to(torch.float16)
def _gen_inputs(dtype):
return [
(torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
(torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
(torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
]
_run_module_test(NeuralNetScaledDotProductAttention, torch.float16, _gen_inputs, 3)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("input_shapes", [([128, 64], [64, 64]), ([16, 64, 128], [16, 128, 64])])
def test_matmul_tunable_op(dtype, input_shapes):