mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Support BFloat16 for Triton Codegen (#20353)
Previous implementation used numpy array and numpy data_type to store constant value and data type, which is not support BFloat16 natively. This PR is to switch to use torch tensor which supports BFloat16.
This commit is contained in:
parent
da86f6f408
commit
c47f446f25
8 changed files with 48 additions and 72 deletions
|
|
@ -14,7 +14,6 @@ Mostly, Nodes are classified into two categories:
|
|||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
from sympy.codegen.rewriting import create_expand_pow_optimization
|
||||
|
|
@ -316,14 +315,14 @@ class TritonCodegen(NodeVisitor):
|
|||
op_type = "Sqrt"
|
||||
|
||||
if op_type == "Cast":
|
||||
from_dtype = node.inputs[0].dtype.type
|
||||
to_dtype = node.outputs[0].dtype.type
|
||||
from_dtype = node.inputs[0].dtype
|
||||
to_dtype = node.outputs[0].dtype
|
||||
if from_dtype == to_dtype or is_number(kwargs["i0"]):
|
||||
op_type = "Identity"
|
||||
elif to_dtype == np.bool_:
|
||||
elif to_dtype == torch.bool:
|
||||
op_type = "CastBool"
|
||||
else:
|
||||
kwargs["dtype"] = to_dtype.__name__
|
||||
kwargs["dtype"] = str(to_dtype)[6:] # Remove "torch." prefix.
|
||||
|
||||
if op_type == "QuickGelu" or op_type == "QuickGeluGrad":
|
||||
kwargs["alpha"] = str(node.attributes.get("alpha", 1.702))
|
||||
|
|
@ -473,7 +472,7 @@ class TritonCodegen(NodeVisitor):
|
|||
code_buffer += "\n"
|
||||
# Allocate output tensor.
|
||||
for output in kernel_node.outputs:
|
||||
torch_dtype = torch.from_numpy(np.zeros(1, dtype=output.dtype)).dtype
|
||||
torch_dtype = output.dtype
|
||||
# Workaround for DLPack which doesn't support bool.
|
||||
if torch_dtype == torch.bool:
|
||||
torch_dtype = torch.uint8
|
||||
|
|
|
|||
|
|
@ -10,11 +10,10 @@ Decompose a complicated op into a series of simple ops.
|
|||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import sympy
|
||||
from onnx import GraphProto, NodeProto, TensorProto, helper
|
||||
|
||||
from ._utils import get_attribute, get_reduce_info, to_numpy_type
|
||||
from ._utils import get_attribute, get_reduce_info
|
||||
|
||||
|
||||
def _is_half_dtype(dtype: int):
|
||||
|
|
@ -132,7 +131,7 @@ class DecomposeDispatch:
|
|||
if axis < 0:
|
||||
axis += rank
|
||||
axes = list(range(axis, rank))
|
||||
epsilon_tensor = helper.make_tensor(name="epsilon_const", data_type=xdtype, dims=(1,), vals=np.array([epsilon]))
|
||||
epsilon_tensor = helper.make_tensor(name="epsilon_const", data_type=xdtype, dims=(1,), vals=[epsilon])
|
||||
const_node, const_out = self._new_node(node_name, "Constant", [], value=epsilon_tensor)
|
||||
reducemean_node, reducemean_out = self._new_node(node_name, "ReduceMean", [x], outputs=[mean], axes=axes)
|
||||
sub_node, sub_out = self._new_node(node_name, "Sub", [x, reducemean_out])
|
||||
|
|
@ -371,7 +370,7 @@ class DecomposeDispatch:
|
|||
name=f"{node_name}_denominator",
|
||||
dims=(),
|
||||
data_type=dtype,
|
||||
vals=np.array([denominator], dtype=to_numpy_type(dtype)),
|
||||
vals=[denominator],
|
||||
)
|
||||
denominator_node, denominator_out = self._new_node(node_name, "Constant", [], value=denominator_tensor)
|
||||
div_node, _ = self._new_node(node_name, "Div", [sum_out, denominator_out], outputs=[y])
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ from abc import abstractmethod
|
|||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
|
||||
from ._common import AutotuneConfigs, CodeBuffer, CodegenContext, NodeVisitor, TensorInfo
|
||||
from ._sympy_utils import parse_shape
|
||||
from ._utils import gen_unique_name, gen_variable_name, sort_reduce_axes, to_numpy_type
|
||||
from ._utils import gen_unique_name, gen_variable_name, sort_reduce_axes, to_torch_dtype
|
||||
|
||||
|
||||
class TensorArg:
|
||||
|
|
@ -22,15 +22,15 @@ class TensorArg:
|
|||
If it's constant (initializer or constant node), it also contains the data in numpy array.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[np.ndarray] = None):
|
||||
def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[torch.Tensor] = None):
|
||||
self._name: str = name
|
||||
self._data: Optional[np.ndarray] = data
|
||||
self._data: Optional[torch.Tensor] = data
|
||||
if data is not None:
|
||||
self._dtype: np.dtype = data.dtype
|
||||
self._dtype: torch.dtype = data.dtype
|
||||
self._shape: List[sympy.Expr] = parse_shape(list(data.shape))
|
||||
else:
|
||||
assert tensor_info is not None
|
||||
self._dtype: np.dtype = to_numpy_type(tensor_info.dtype)
|
||||
self._dtype: torch.dtype = to_torch_dtype(tensor_info.dtype)
|
||||
self._shape: List[sympy.Expr] = tensor_info.shape
|
||||
self.cross_kernels: bool = False
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ class TensorArg:
|
|||
return self._name
|
||||
|
||||
@property
|
||||
def dtype(self) -> np.dtype:
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
|
|
@ -47,7 +47,7 @@ class TensorArg:
|
|||
return self._shape
|
||||
|
||||
@property
|
||||
def data(self) -> Optional[np.ndarray]:
|
||||
def data(self) -> Optional[torch.Tensor]:
|
||||
return self._data
|
||||
|
||||
|
||||
|
|
@ -328,11 +328,10 @@ class KernelNode(IRNode):
|
|||
self.var_map[name] = gen_variable_name(name, "c", existing_names)
|
||||
if tensor_arg.data is not None:
|
||||
value = tensor_arg.data
|
||||
if value is not None:
|
||||
assert value.size == 1, f"unsupported constant array {value}"
|
||||
variable_name = self.var_map[name]
|
||||
assert variable_name not in self.var_map
|
||||
self.var_map[variable_name] = str(np.array(value.item(), value.dtype))
|
||||
assert value.numel() == 1, f"unsupported constant {value} which has more than one element."
|
||||
variable_name = self.var_map[name]
|
||||
assert variable_name not in self.var_map
|
||||
self.var_map[variable_name] = str(value.item())
|
||||
|
||||
|
||||
class ElementwiseKernelNode(KernelNode):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from ._ir import (
|
|||
)
|
||||
from ._op_config import is_reduction_node
|
||||
from ._sorted_graph import SortedGraph
|
||||
from ._utils import get_reduce_info, sort_reduce_axes, to_numpy_array
|
||||
from ._utils import get_reduce_info, sort_reduce_axes, to_torch_tensor
|
||||
|
||||
|
||||
class NodeGroup:
|
||||
|
|
@ -245,10 +245,10 @@ class GraphLowering:
|
|||
self._module_outputs = [TensorArg(output.name, self._node_arg_infos[output.name]) for output in graph.output]
|
||||
self._module_output_names = set(arg.name for arg in self._module_outputs)
|
||||
for initializer in graph.initializer:
|
||||
data = to_numpy_array(initializer)
|
||||
data = to_torch_tensor(initializer)
|
||||
self._module_constants.append(TensorArg(initializer.name, data=data))
|
||||
for const_node in self._sorted_graph.const_nodes:
|
||||
data = to_numpy_array(const_node)
|
||||
data = to_torch_tensor(const_node)
|
||||
self._module_constants.append(TensorArg(const_node.output[0], data=data))
|
||||
self._module_constant_names = set(arg.name for arg in self._module_constants)
|
||||
self._tensor_args = dict(
|
||||
|
|
@ -415,7 +415,7 @@ class GraphLowering:
|
|||
for idx in range(cur, nxt):
|
||||
for input in sub_nodes[idx].inputs:
|
||||
if input.name in kernel_node.constants or input.name in input_names:
|
||||
if (input.data is not None and input.data.size == 1) or input.name in load_cache:
|
||||
if (input.data is not None and input.data.numel() == 1) or input.name in load_cache:
|
||||
continue
|
||||
load_nodes.append(IONode(input, kernel_node.offset_calc, True))
|
||||
load_cache.add(input.name)
|
||||
|
|
@ -432,7 +432,7 @@ class GraphLowering:
|
|||
for reduce_node in sub_nodes[nxt].reduce_nodes:
|
||||
input = reduce_node.inputs[0]
|
||||
if input.name in kernel_node.constants or input.name in input_names:
|
||||
if (input.data is not None and input.data.size == 1) or input.name in load_cache:
|
||||
if (input.data is not None and input.data.numel() == 1) or input.name in load_cache:
|
||||
continue
|
||||
load_nodes.append(IONode(input, kernel_node.offset_calc, True))
|
||||
load_cache.add(input.name)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import copy
|
|||
import itertools
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import sympy
|
||||
from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper
|
||||
|
|
@ -16,7 +15,7 @@ from ._common import SymbolicDSU, TensorInfo, TypeAndShapeInfer
|
|||
from ._decompose import DecomposeDispatch
|
||||
from ._op_config import is_elementwise_node
|
||||
from ._sympy_utils import parse_shape
|
||||
from ._utils import get_attribute, to_numpy_array, topological_sort
|
||||
from ._utils import get_attribute, to_torch_tensor, topological_sort
|
||||
|
||||
|
||||
class SortedGraph:
|
||||
|
|
@ -58,7 +57,7 @@ class SortedGraph:
|
|||
for initializer in self._graph.initializer:
|
||||
self._node_arg_infos[initializer.name] = TensorInfo(
|
||||
initializer.data_type,
|
||||
parse_shape(list(to_numpy_array(initializer).shape)),
|
||||
parse_shape(list(initializer.dims)),
|
||||
)
|
||||
|
||||
# Decompose complex operators.
|
||||
|
|
@ -97,13 +96,13 @@ class SortedGraph:
|
|||
|
||||
constants = []
|
||||
for idx, initializer in enumerate(self._sorted_initializers):
|
||||
data_str = np.array2string(to_numpy_array(initializer), separator=",").replace("\n", "").replace(" ", "")
|
||||
data_str = str(to_torch_tensor(initializer).tolist()).replace("\n", "").replace(" ", "")
|
||||
constants.append(f"({initializer.data_type},{data_str})")
|
||||
name_map[initializer.name] = f"c{idx}"
|
||||
|
||||
for idx, node in enumerate(self._const_nodes):
|
||||
value_attr = get_attribute(node, "value")
|
||||
data_str = np.array2string(to_numpy_array(value_attr), separator=",").replace("\n", "").replace(" ", "")
|
||||
data_str = str(to_torch_tensor(value_attr).tolist()).replace("\n", "").replace(" ", "")
|
||||
constants.append(f"({value_attr.data_type},{data_str})")
|
||||
name_map[node.output[0]] = f"c{idx + len(self._sorted_initializers)}"
|
||||
constants_str = ",".join(constants)
|
||||
|
|
@ -181,7 +180,7 @@ class SortedGraph:
|
|||
value_attr = get_attribute(node, "value")
|
||||
self._node_arg_infos[node.output[0]] = TensorInfo(
|
||||
value_attr.data_type,
|
||||
parse_shape(list(to_numpy_array(value_attr).shape)),
|
||||
parse_shape(list(value_attr.dims)),
|
||||
)
|
||||
else:
|
||||
input_infos = []
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from collections import defaultdict
|
|||
from typing import Any, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper
|
||||
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||
|
||||
|
||||
def gen_unique_name(prefix: str) -> str:
|
||||
|
|
@ -73,17 +73,24 @@ def get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) ->
|
|||
return default_value
|
||||
|
||||
|
||||
# Convert Constant node or TensorProto to numpy array.
|
||||
def to_numpy_array(node: Any) -> np.ndarray:
|
||||
# Convert Constant node or TensorProto to torch.Tensor.
|
||||
def to_torch_tensor(node: Any) -> torch.Tensor:
|
||||
tensor = node
|
||||
if isinstance(node, NodeProto):
|
||||
tensor = get_attribute(node, "value")
|
||||
assert isinstance(tensor, TensorProto)
|
||||
return numpy_helper.to_array(tensor)
|
||||
torch_tensor = torch.from_numpy(numpy_helper.to_array(tensor))
|
||||
# numpy does not support bfloat16 and create a float32 tensor instead.
|
||||
if tensor.data_type == TensorProto.BFLOAT16:
|
||||
torch_tensor = torch_tensor.to(torch.bfloat16)
|
||||
return torch_tensor
|
||||
|
||||
|
||||
def to_numpy_type(tensor_type: TensorProto.DataType) -> np.dtype:
|
||||
return TENSOR_TYPE_TO_NP_TYPE[tensor_type] if not isinstance(tensor_type, np.dtype) else tensor_type
|
||||
def to_torch_dtype(tensor_type: TensorProto.DataType) -> torch.dtype:
|
||||
# Native numpy does not support bfloat16.
|
||||
if tensor_type == TensorProto.BFLOAT16:
|
||||
return torch.bfloat16
|
||||
return torch.from_numpy(np.zeros(1, dtype=helper.tensor_dtype_to_np_dtype(tensor_type))).dtype
|
||||
|
||||
|
||||
# Generate a unique variable name based on the node arg name.
|
||||
|
|
@ -133,7 +140,7 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl
|
|||
axes_initializer = initializer
|
||||
break
|
||||
assert axes_initializer is not None
|
||||
axes = to_numpy_array(axes_initializer).tolist()
|
||||
axes = to_torch_tensor(axes_initializer).tolist()
|
||||
if axes is None:
|
||||
axes = list(range(input_rank)) if noop_with_empty_axes == 0 else []
|
||||
axes = sort_reduce_axes(axes, input_rank, check_contiguous=False)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from onnx import TensorProto, helper
|
|||
|
||||
from onnxruntime.training.ortmodule import register_graph_optimizer
|
||||
|
||||
from .._utils import get_attribute, to_numpy_array
|
||||
from .._utils import get_attribute, to_torch_tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
|
@ -212,7 +212,7 @@ def _get_constant(graph, arg):
|
|||
initializer = init
|
||||
if initializer is None:
|
||||
return None
|
||||
return to_numpy_array(initializer)
|
||||
return to_torch_tensor(initializer).tolist()
|
||||
|
||||
|
||||
def _check_slice(graph, node, start, end, axis, step):
|
||||
|
|
@ -224,9 +224,9 @@ def _check_slice(graph, node, start, end, axis, step):
|
|||
axis += rank
|
||||
for idx, value in enumerate([start, end, axis, step]):
|
||||
constant = _get_constant(graph, node.input[idx + 1])
|
||||
if constant is None or constant.size != 1:
|
||||
if constant is None or len(constant) != 1:
|
||||
return False
|
||||
constant_value = constant.item()
|
||||
constant_value = constant[0]
|
||||
if idx == 2 and constant_value < 0:
|
||||
constant_value += rank
|
||||
if constant_value != value:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import onnx
|
|||
import pytest
|
||||
import torch
|
||||
from onnx import TensorProto, helper
|
||||
from packaging.version import Version
|
||||
from torch._C import _from_dlpack
|
||||
from torch.utils.dlpack import to_dlpack
|
||||
|
||||
|
|
@ -843,32 +842,6 @@ def test_slice_scel_module(dtype, has_sum):
|
|||
_run_module_test(NeuralNetSliceScel, dtype, _gen_inputs, 2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
Version(torch.__version__) < Version("2.1"), reason="PyTorch has scaled_dot_product_attention since 2.1."
|
||||
)
|
||||
def test_scaled_dot_product_attention_module():
|
||||
class NeuralNetScaledDotProductAttention(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
|
||||
self.linear2 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
|
||||
self.linear3 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
self.linear1(q), self.linear2(k), self.linear3(v)
|
||||
).to(torch.float16)
|
||||
|
||||
def _gen_inputs(dtype):
|
||||
return [
|
||||
(torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
|
||||
(torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
|
||||
(torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE),
|
||||
]
|
||||
|
||||
_run_module_test(NeuralNetScaledDotProductAttention, torch.float16, _gen_inputs, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
|
||||
@pytest.mark.parametrize("input_shapes", [([128, 64], [64, 64]), ([16, 64, 128], [16, 128, 64])])
|
||||
def test_matmul_tunable_op(dtype, input_shapes):
|
||||
|
|
|
|||
Loading…
Reference in a new issue