mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add allow_xpu to enable XPU UTs (#130312)
# Motivation enable UTs under folder test/xpu/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/130312 Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
parent
fc238db62a
commit
c6aa03bd4e
3 changed files with 12 additions and 2 deletions
|
|
@ -29,6 +29,10 @@ Tensor& addmm_out(
|
|||
"x",
|
||||
mat2.sizes()[1],
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
mat1.dtype() == mat2.dtype(),
|
||||
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
|
||||
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
|
||||
result.resize_(result_shape);
|
||||
|
|
@ -131,6 +135,10 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
|
|||
"x",
|
||||
mat2.sizes()[1],
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
self.dtype() == mat2.dtype(),
|
||||
"expected self and mat2 to have the same dtype, but got: ", self.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
|
||||
result.resize_({self.size(0), mat2.size(1)});
|
||||
if (self.numel() == 0 || mat2.numel() == 0) {
|
||||
|
|
|
|||
|
|
@ -1264,7 +1264,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
|
|||
assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), only_for="xpu")
|
||||
instantiate_device_type_tests(
|
||||
TestConvolutionNNDeviceType, globals(), only_for="xpu", allow_xpu=True
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1142,7 +1142,7 @@ class TestBasicGEMM(TestCase):
|
|||
torch.matmul(a, b, out=c)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu")
|
||||
instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
Loading…
Reference in a new issue