diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index f1c13adb64b..84aa47574de 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -46,7 +46,7 @@ from torch.testing._internal.common_utils import ( TEST_WITH_ROCM, ) from torch.testing._internal.custom_tensor import CustomTensorPlainOut -from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu from torch.utils import _pytree as pytree @@ -736,16 +736,13 @@ class AOTInductorTestsTemplate: example_inputs = (x, y) self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FP8, - "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", - ) @skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform @skipIfXpu def test_fp8(self): - # cuda only - if self.device != "cuda": - return + if self.device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest( + "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" + ) class Model(torch.nn.Module): def __init__(self, dtype): @@ -766,16 +763,18 @@ class AOTInductorTestsTemplate: dtype = torch.float16 - a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE) - b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE) - input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype) + a_scale = torch.Tensor([1.0]).to(device=self.device) + b_scale = torch.Tensor([1.0]).to(device=self.device) + input_bias = torch.rand(32, device=self.device, dtype=dtype) weight_shape = (32, 16) - weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T + weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).T a_inverse_scale = 1 / a_scale b_inverse_scale = 1 / b_scale x_shape = (16, 16) - x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn) + x = torch.rand(*x_shape, device=self.device, dtype=dtype).to( + torch.float8_e4m3fn + ) dim0_x = Dim("dim0_x", min=1, max=2048) dynamic_shapes = ({0: dim0_x}, None, None, None, None) self.check_model( @@ -784,16 +783,13 @@ class AOTInductorTestsTemplate: dynamic_shapes=dynamic_shapes, ) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FP8, - "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", - ) @skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform @skipIfXpu def test_fp8_view_of_param(self): - # cuda only - if self.device != GPU_TYPE: - return + if self.device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest( + "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" + ) class Model(torch.nn.Module): def __init__(self, dtype, weight): @@ -4521,7 +4517,7 @@ copy_tests( ) -@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") +@unittest.skipIf(sys.platform == "darwin" or not HAS_GPU, "No CUDA on MacOS") class AOTInductorTestABICompatibleGpu(TestCase): device = GPU_TYPE device_type = GPU_TYPE @@ -4543,5 +4539,5 @@ if __name__ == "__main__": from torch._inductor.test_case import run_tests # cpp_extension N/A in fbcode - if HAS_GPU or sys.platform == "darwin": + if HAS_GPU or sys.platform == "darwin" or HAS_CPU: run_tests(needs="filelock") diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index f22ad01ce33..77d342635ca 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -62,6 +62,8 @@ // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check #include +#include +#include #include #include @@ -173,6 +175,12 @@ aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64( AtenTensorHandle tensor, c10::complex* ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e4m3fn( + AtenTensorHandle tensor, + c10::Float8_e4m3fn* ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_float8_e5m2( + AtenTensorHandle tensor, + c10::Float8_e5m2* ret_value); // Functions for wrapping a scalar value to a single-element tensor AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32( @@ -719,6 +727,8 @@ int32_t aoti_torch_dtype() = delete; namespace c10 { struct BFloat16; struct Half; +struct Float8_e4m3fn; +struct Float8_e5m2; } // namespace c10 DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) @@ -732,6 +742,8 @@ DEFINE_DTYPE_SPECIALIZATION(int16_t, int16) DEFINE_DTYPE_SPECIALIZATION(int32_t, int32) DEFINE_DTYPE_SPECIALIZATION(int64_t, int64) DEFINE_DTYPE_SPECIALIZATION(bool, bool) +DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e4m3fn, float8_e4m3fn) +DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e5m2, float8_e5m2) #endif diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index b0cebc6778a..5f67bb08536 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -201,6 +201,8 @@ AOTI_TORCH_ITEM_IMPL(int64, int64_t) AOTI_TORCH_ITEM_IMPL(bool, bool) AOTI_TORCH_ITEM_IMPL(bfloat16, c10::BFloat16) AOTI_TORCH_ITEM_IMPL(complex64, c10::complex) +AOTI_TORCH_ITEM_IMPL(float8_e4m3fn, c10::Float8_e4m3fn) +AOTI_TORCH_ITEM_IMPL(float8_e5m2, c10::Float8_e5m2) #undef AOTI_TORCH_ITEM_IMPL #define AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(dtype, ctype, ttype) \