mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add CPU inductor support for _scaled_mm
ghstack-source-id: 4cd3e37881
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141961
This commit is contained in:
parent
489513f482
commit
a080c435b8
3 changed files with 32 additions and 22 deletions
|
|
@ -46,7 +46,7 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.custom_tensor import CustomTensorPlainOut
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU
|
||||
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
||||
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
|
||||
from torch.utils import _pytree as pytree
|
||||
|
|
@ -736,16 +736,13 @@ class AOTInductorTestsTemplate:
|
|||
example_inputs = (x, y)
|
||||
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FP8,
|
||||
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
||||
)
|
||||
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
||||
@skipIfXpu
|
||||
def test_fp8(self):
|
||||
# cuda only
|
||||
if self.device != "cuda":
|
||||
return
|
||||
if self.device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||
raise unittest.SkipTest(
|
||||
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
|
||||
)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, dtype):
|
||||
|
|
@ -766,16 +763,18 @@ class AOTInductorTestsTemplate:
|
|||
|
||||
dtype = torch.float16
|
||||
|
||||
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
||||
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
||||
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
|
||||
a_scale = torch.Tensor([1.0]).to(device=self.device)
|
||||
b_scale = torch.Tensor([1.0]).to(device=self.device)
|
||||
input_bias = torch.rand(32, device=self.device, dtype=dtype)
|
||||
weight_shape = (32, 16)
|
||||
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
|
||||
weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).T
|
||||
a_inverse_scale = 1 / a_scale
|
||||
b_inverse_scale = 1 / b_scale
|
||||
|
||||
x_shape = (16, 16)
|
||||
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
|
||||
x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
dim0_x = Dim("dim0_x", min=1, max=2048)
|
||||
dynamic_shapes = ({0: dim0_x}, None, None, None, None)
|
||||
self.check_model(
|
||||
|
|
@ -784,16 +783,13 @@ class AOTInductorTestsTemplate:
|
|||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FP8,
|
||||
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
||||
)
|
||||
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
||||
@skipIfXpu
|
||||
def test_fp8_view_of_param(self):
|
||||
# cuda only
|
||||
if self.device != GPU_TYPE:
|
||||
return
|
||||
if self.device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||
raise unittest.SkipTest(
|
||||
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
|
||||
)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, dtype, weight):
|
||||
|
|
@ -4521,7 +4517,7 @@ copy_tests(
|
|||
)
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
|
||||
@unittest.skipIf(sys.platform == "darwin" or not HAS_GPU, "No CUDA on MacOS")
|
||||
class AOTInductorTestABICompatibleGpu(TestCase):
|
||||
device = GPU_TYPE
|
||||
device_type = GPU_TYPE
|
||||
|
|
@ -4543,5 +4539,5 @@ if __name__ == "__main__":
|
|||
from torch._inductor.test_case import run_tests
|
||||
|
||||
# cpp_extension N/A in fbcode
|
||||
if HAS_GPU or sys.platform == "darwin":
|
||||
if HAS_GPU or sys.platform == "darwin" or HAS_CPU:
|
||||
run_tests(needs="filelock")
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@
|
|||
// The following files are implemented in a header-only way and are guarded by
|
||||
// test/cpp/aoti_abi_check
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
|
|
@ -173,6 +175,12 @@ aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value);
|
|||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64(
|
||||
AtenTensorHandle tensor,
|
||||
c10::complex<float>* ret_value);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e4m3fn(
|
||||
AtenTensorHandle tensor,
|
||||
c10::Float8_e4m3fn* ret_value);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e5m2(
|
||||
AtenTensorHandle tensor,
|
||||
c10::Float8_e5m2* ret_value);
|
||||
|
||||
// Functions for wrapping a scalar value to a single-element tensor
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32(
|
||||
|
|
@ -719,6 +727,8 @@ int32_t aoti_torch_dtype() = delete;
|
|||
namespace c10 {
|
||||
struct BFloat16;
|
||||
struct Half;
|
||||
struct Float8_e4m3fn;
|
||||
struct Float8_e5m2;
|
||||
} // namespace c10
|
||||
|
||||
DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16)
|
||||
|
|
@ -732,6 +742,8 @@ DEFINE_DTYPE_SPECIALIZATION(int16_t, int16)
|
|||
DEFINE_DTYPE_SPECIALIZATION(int32_t, int32)
|
||||
DEFINE_DTYPE_SPECIALIZATION(int64_t, int64)
|
||||
DEFINE_DTYPE_SPECIALIZATION(bool, bool)
|
||||
DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e4m3fn, float8_e4m3fn)
|
||||
DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e5m2, float8_e5m2)
|
||||
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -201,6 +201,8 @@ AOTI_TORCH_ITEM_IMPL(int64, int64_t)
|
|||
AOTI_TORCH_ITEM_IMPL(bool, bool)
|
||||
AOTI_TORCH_ITEM_IMPL(bfloat16, c10::BFloat16)
|
||||
AOTI_TORCH_ITEM_IMPL(complex64, c10::complex<float>)
|
||||
AOTI_TORCH_ITEM_IMPL(float8_e4m3fn, c10::Float8_e4m3fn)
|
||||
AOTI_TORCH_ITEM_IMPL(float8_e5m2, c10::Float8_e5m2)
|
||||
#undef AOTI_TORCH_ITEM_IMPL
|
||||
|
||||
#define AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(dtype, ctype, ttype) \
|
||||
|
|
|
|||
Loading…
Reference in a new issue