From c6aa03bd4e9fe5b223eef2e97fda0bc61d593b1f Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 12 Jul 2024 15:49:38 +0000 Subject: [PATCH] Add allow_xpu to enable XPU UTs (#130312) # Motivation enable UTs under folder test/xpu/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/130312 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD --- aten/src/ATen/native/mkldnn/xpu/Blas.cpp | 8 ++++++++ test/xpu/test_conv.py | 4 +++- test/xpu/test_gemm.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index 78882d5737f..518ce8a4f1d 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -29,6 +29,10 @@ Tensor& addmm_out( "x", mat2.sizes()[1], ")"); + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ) std::vector result_shape = {mat1.size(0), mat2.size(1)}; result.resize_(result_shape); @@ -131,6 +135,10 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) { "x", mat2.sizes()[1], ")"); + TORCH_CHECK( + self.dtype() == mat2.dtype(), + "expected self and mat2 to have the same dtype, but got: ", self.dtype(), " != ", mat2.dtype() + ) result.resize_({self.size(0), mat2.size(1)}); if (self.numel() == 0 || mat2.numel() == 0) { diff --git a/test/xpu/test_conv.py b/test/xpu/test_conv.py index f3d4375213f..a6bbefcff2d 100644 --- a/test/xpu/test_conv.py +++ b/test/xpu/test_conv.py @@ -1264,7 +1264,9 @@ class TestConvolutionNNDeviceType(NNTestCase): assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512)) -instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), only_for="xpu") +instantiate_device_type_tests( + TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index 0157677a582..2bc6d09eeea 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -1142,7 +1142,7 @@ class TestBasicGEMM(TestCase): torch.matmul(a, b, out=c) -instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu") +instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) if __name__ == "__main__": run_tests()