onnxruntime/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# pylint: disable=C0116,W0212,R1720,C0103,C0114
#
# Note: precision differs across GPUs (for example V100 vs. H100) even with the same code.
# The thresholds were adjusted on H100, where precision appears to be lower.
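#
# These tests compare the com.microsoft GemmFloat8 contrib op against a NumPy
# reference and a plain float32 Gemm on the CUDA and ROCm execution providers.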
import itertools
import unittest
import warnings
import numpy as np
import parameterized
from numpy.testing import assert_allclose
from onnx import TensorProto
from onnx.checker import check_model
from onnx.defs import onnx_opset_version
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
from onnx.numpy_helper import from_array
from onnxruntime import InferenceSession, get_available_providers

available_providers = list(get_available_providers())


class TestFloat8Gemm8(unittest.TestCase):
def get_model_gemm(
self,
a_float_name="FLOAT",
b_float_name="FLOAT",
c_float_name="FLOAT",
alpha=1.0,
beta=0.0,
transA=0,
transB=0,
scaleA=True,
scaleB=True,
scaleY=True,
domain="",
dtype=TensorProto.FLOAT,
activation="NONE",
):
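        """Build an ONNX model that casts float32 inputs to the requested types,
        runs Gemm (standard domain) or com.microsoft GemmFloat8, and casts the
        result back to float32 so outputs are comparable across types.
        """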
a_proto_type = getattr(TensorProto, a_float_name)
b_proto_type = getattr(TensorProto, b_float_name)
c_proto_type = getattr(TensorProto, c_float_name)
f8_set = {TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2}
        use_f8 = not {a_proto_type, b_proto_type, c_proto_type}.isdisjoint(f8_set)
a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None])
b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None])
d = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])
inits = []
kwargs = {}
node_inputs = ["Af", "Bf"]
inputs = [a, b]
bias = beta != 0
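        # com.microsoft GemmFloat8 accepts up to six inputs: A, B, C, scaleA,
        # scaleB, scaleY. An empty name marks an omitted optional input; the
        # initializer "one" feeds a constant scale of 1.0.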
if bias:
inputs.append(make_tensor_value_info("C", TensorProto.FLOAT, [None, None]))
node_inputs = ["Af", "Bf", "Cf"]
if use_f8:
node_inputs.append("one" if scaleA else "")
node_inputs.append("one" if scaleB else "")
node_inputs.append("one" if scaleY else "")
elif use_f8:
node_inputs.append("")
node_inputs.append("one" if scaleA else "")
node_inputs.append("one" if scaleB else "")
node_inputs.append("one" if scaleY else "")
if use_f8:
assert domain == "com.microsoft"
inits.append(from_array(np.array([1], dtype=np.float32), name="one"))
kwargs = dict(
domain=domain,
dtype=dtype,
)
if activation is not None:
kwargs["activation"] = activation
op_name = "GemmFloat8"
elif domain == "com.microsoft":
op_name = "GemmFloat8"
kwargs = dict(
domain=domain,
dtype=dtype,
)
else:
op_name = "Gemm"
nodes = [
make_node("Cast", ["A"], ["Af"], to=a_proto_type),
make_node("Cast", ["B"], ["Bf"], to=b_proto_type),
make_node("Cast", ["C"], ["Cf"], to=c_proto_type) if bias else None,
make_node(
op_name,
node_inputs,
["Yf"],
transA=transA,
transB=transB,
alpha=alpha,
beta=beta,
**kwargs,
),
make_node("Cast", ["Yf"], ["Y"], to=TensorProto.FLOAT),
]
nodes = [n for n in nodes if n is not None]
graph = make_graph(nodes, "gemm", inputs, [d], inits)
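        # Target the previous opset so the model does not require the newest
        # opset, which this onnxruntime build may not support yet (assumed intent).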
opset_imports = [make_opsetid("", onnx_opset_version() - 1)]
if domain == "com.microsoft":
opset_imports.append(make_opsetid("com.microsoft", 1))
onnx_model = make_model(graph, opset_imports=opset_imports, ir_version=9)
if domain != "com.microsoft":
check_model(onnx_model)
return onnx_model
def common_test_model_gemm(
self,
a_float_name="FLOAT",
b_float_name="FLOAT",
c_float_name="FLOAT",
mul=0.33,
atol=0,
rtol=0,
square=True,
**kwargs,
):
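        """Check a GEMM three ways: a NumPy reference, a float32 ONNX model, and
        a com.microsoft GemmFloat8 model with the requested input/output types,
        comparing results within the given tolerances.
        """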
if square:
a = (np.arange(256) * 0.01).astype(np.float32).reshape((-1, 16))
b = (np.arange(256) * -0.01).astype(np.float32).reshape((-1, 16))
c = (np.arange(256) * 0.03).astype(np.float32).reshape((-1, 16))
b[:, 0] += 1
else:
a = (np.arange(256) / 256).astype(np.float32).reshape((32, -1))
b = (np.arange(512) / 512).astype(np.float32).reshape((32, -1))
c = (np.arange(128) / 128).astype(np.float32).reshape((8, 16))
feeds = {"A": a, "B": b}
providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in available_providers:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
elif "ROCMExecutionProvider" in available_providers:
providers = [
("ROCMExecutionProvider", {"tunable_op_enable": "1", "tunable_op_tuning_enable": "1"}),
("CPUExecutionProvider", {}),
]
expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b)
expected *= kwargs.get("alpha", 1.0)
if kwargs.get("beta", 0) != 0:
expected += kwargs["beta"] * c
feeds["C"] = c
onnx_model = self.get_model_gemm(**kwargs)
ref = InferenceSession(onnx_model.SerializeToString(), providers=providers)
y = ref.run(None, feeds)[0]
if (
"CUDAExecutionProvider" in providers
and a_float_name in ("FLOAT", "FLOAT16")
and b_float_name in ("FLOAT", "FLOAT16")
and c_float_name in ("FLOAT", "FLOAT16")
):
try:
assert_allclose(expected, y, atol=atol, rtol=rtol)
except Exception as e:
def check(f):
try:
return f()[:2, :2]
except Exception as e:
return str(e)
raise AssertionError(
f"Gemm ERROR len(inputs)={len(feeds)}"
f"\na@b=\n{check(lambda: a @ b)}"
f"\na.T@b=\n{check(lambda: a.T @ b)}"
f"\na@b.T=\n{check(lambda: a @ b.T)}"
f"\na.T@b.T=\n{check(lambda: a.T @ b.T)}"
f"\n----\nb@a=\n{check(lambda: b @ a)}"
f"\nb.T@a=\n{check(lambda: b.T @ a)}"
f"\nb@a.T=\n{check(lambda: b @ a.T)}"
f"\nb.T@a.T=\n{check(lambda: b.T @ a.T)}"
f"\n----\nexpected=\n{expected[:2, :2]}"
f"\n----\ngot=\n{y[:2, :2]}"
f"\nkwargs={kwargs}"
) from e
self.assertEqual(expected.shape, y.shape)
self.assertEqual(expected.dtype, y.dtype)
onnx_model_f8 = self.get_model_gemm(
a_float_name=a_float_name,
b_float_name=b_float_name,
c_float_name=c_float_name,
domain="com.microsoft",
**kwargs,
)
try:
ref8 = InferenceSession(onnx_model_f8.SerializeToString(), providers=providers)
except Exception as e:
if "CUDA < 12.0 does not support bias" in str(e):
return
raise AssertionError(f"Could not load model {onnx_model_f8}") from e
try:
y = ref8.run(None, feeds)[0]
except Exception as e:
if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e):
# Skipping. This machine does not support float8.
warnings.warn("unable to test with float8 on this machine.")
return
if "CK is required to support GemmFloat8 computing" in str(e):
warnings.warn("unable to test with float8 on this build.")
return
raise AssertionError(f"Could not execute model {onnx_model_f8}") from e
try:
assert_allclose(expected, y, atol=atol, rtol=rtol)
except Exception as e:
def check(f):
try:
return f()[:2, :2]
except Exception as e:
return str(e)
raise AssertionError(
f"Gemm ERROR len(inputs)={len(feeds)}"
f"\na@b=\n{check(lambda: a @ b)}"
f"\na.T@b=\n{check(lambda: a.T @ b)}"
f"\na@b.T=\n{check(lambda: a @ b.T)}"
f"\na.T@b.T=\n{check(lambda: a.T @ b.T)}"
f"\n----\nb@a=\n{check(lambda: b @ a)}"
f"\nb.T@a=\n{check(lambda: b.T @ a)}"
f"\nb@a.T=\n{check(lambda: b @ a.T)}"
f"\nb.T@a.T=\n{check(lambda: b.T @ a.T)}"
f"\n----\nexpected=\n{expected[:2, :2]}"
f"\n----\ngot=\n{y[:2, :2]}"
f"\nkwargs={kwargs}"
) from e
self.assertEqual(expected.shape, y.shape)
self.assertEqual(expected.dtype, y.dtype)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float(self):
self.common_test_model_gemm(transA=1, rtol=1e-3)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_default_values(self):
self.common_test_model_gemm(transA=1, rtol=1e-3, activation=None)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_relu(self):
self.common_test_model_gemm(transA=1, rtol=1e-3, activation="RELU")
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_gelu(self):
self.common_test_model_gemm(transA=1, rtol=1e-3, activation="GELU")
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float_bias(self):
self.common_test_model_gemm(transA=1, beta=1.0, rtol=1e-3)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_model_gemm_float16(self):
self.common_test_model_gemm(
a_float_name="FLOAT16",
b_float_name="FLOAT16",
c_float_name="FLOAT16",
rtol=1e-2,
dtype=TensorProto.FLOAT16,
transB=1,
)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
@unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
def test_model_gemm_float8_e4m3(self):
self.common_test_model_gemm(
a_float_name="FLOAT8E4M3FN",
b_float_name="FLOAT8E4M3FN",
c_float_name="FLOAT8E4M3FN",
rtol=0.5,
dtype=TensorProto.FLOAT,
transA=0,
transB=1,
alpha=10.0,
)
@parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1])))
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_combinations_square_matrices(self, transA, transB):
self.common_test_model_gemm(transA=transA, transB=transB, rtol=1e-3)
@parameterized.parameterized.expand(
[
((2, 3), (3, 5), 0, 0),
((2, 3), (5, 3), 0, 1),
((2, 3), (5, 2), 1, 1),
((2, 3), (2, 5), 1, 0),
]
)
@unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.")
def test_combinations(self, shapeA, shapeB, transA, transB):
model = make_model(
make_graph(
[
make_node(
"GemmFloat8",
["A", "B"],
["Y"],
transA=transA,
transB=transB,
domain="com.microsoft",
)
],
"f8",
[
make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
],
[make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])],
),
opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)],
)
sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
a = np.arange(np.prod(shapeA)).reshape(shapeA).astype(np.float32)
b = np.arange(np.prod(shapeB)).reshape(shapeB).astype(np.float32)
try:
expected = (a.T if transA else a) @ (b.T if transB else b)
except Exception as e:
raise AssertionError(
f"Unable to multiply shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}"
) from e
try:
got = sess.run(None, {"A": a, "B": b})
except Exception as e:
raise AssertionError(
f"Unable to run Gemm with shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}"
) from e
self.assertEqual(expected.shape, got[0].shape)
self.assertEqual(expected.dtype, got[0].dtype)
assert_allclose(expected, got[0])
@parameterized.parameterized.expand(
[
("FLOAT8E4M3FN", "FLOAT16", 0, 0),
("FLOAT16", "FLOAT8E4M3FN", 0, 0),
("FLOAT16", "FLOAT8E4M3FN", 0, 1),
]
)
@unittest.skipIf("ROCMExecutionProvider" not in available_providers, reason="Not running without ROCm.")
@unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0")
def test_model_rocm_gemm_float8_e4m3(self, a_float_name, b_float_name, transA, transB):
self.common_test_model_gemm(
a_float_name=a_float_name,
b_float_name=b_float_name,
c_float_name="FLOAT8E4M3FN",
rtol=0.5,
dtype=TensorProto.FLOAT16,
transA=0,
transB=transB,
scaleY=False,
alpha=10.0,
beta=0.0,
)
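

# A minimal standalone sketch (illustration only, not part of the test suite):
# build a com.microsoft GemmFloat8 model by hand and run it on CUDA, assuming a
# CUDA-enabled onnxruntime build. "_demo_gemm_float8" is a name introduced here
# for the example, not an existing helper.
def _demo_gemm_float8():
    model = make_model(
        make_graph(
            [make_node("GemmFloat8", ["A", "B"], ["Y"], domain="com.microsoft")],
            "demo",
            [
                make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
                make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
            ],
            [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])],
        ),
        opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)],
    )
    sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
    a = np.arange(16, dtype=np.float32).reshape(4, 4)
    b = np.eye(4, dtype=np.float32)
    # Y should equal A @ B since B is the identity matrix.
    return sess.run(None, {"A": a, "B": b})[0]
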
if __name__ == "__main__":
# TestFloat8Gemm8().test_model_gemm_float()
unittest.main(verbosity=2)