From 0dc03134d9eb75b1fb6e49d78744673d291a14df Mon Sep 17 00:00:00 2001
From: Isalia20
Date: Thu, 6 Feb 2025 00:57:49 +0000
Subject: [PATCH] [MPS] linalg solve implementation (#146531)

Fixes #98222
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp |   7 +-
 .../native/mps/operations/LinearAlgebra.mm  | 188 ++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml  |   1 +
 test/test_mps.py                            |  82 +++++++-
 4 files changed, 272 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 8ab08b76a67..a72de4dfd9e 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -590,15 +590,16 @@ TORCH_META_FUNC(_linalg_solve_ex)(const Tensor& A,
   TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
 
   auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1) : B_broad_shape;
-  auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/left);
+  // row major for mps implementation
+  auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/A.device().type() != at::kMPS? left : false);
 
   set_output_strided(0, result_shape, result_strides, B.options(), {});
 
   auto shape = A.sizes();
   auto ndim = shape.size();
 
-  // LU
-  auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true);
+  // LU, row major for mps
+  auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/A.device().type() != at::kMPS? true : false);
   set_output_strided(1, shape, LU_strides, A.options(), {});
 
   // pivots
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index dafee5cce62..02ae024049b 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -13,6 +13,7 @@
 #include
 #include
 #else
+#include
 #include
 #include
 #include
@@ -241,6 +242,181 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
   }
 }
 
+static void linalg_solve_out_mps_impl(const at::Tensor& A,
+                                      const at::Tensor& B,
+                                      bool left,
+                                      bool check_errors,
+                                      const at::Tensor& result,
+                                      const at::Tensor& LU,
+                                      const at::Tensor& pivots,
+                                      const at::Tensor& info) {
+  using namespace mps;
+
+  TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
+              "linalg.lu_factor(): MPS doesn't support complex types.");
+  Tensor A_t, B_t;
+  // If 'left' is false, reinterpret the problem so that Ax = B becomes A^T ⋅ (x^T) = B^T
+  // Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output
+  if (left) {
+    A_t = A.contiguous();
+    B_t = B.contiguous();
+  } else {
+    A_t = A.transpose(-2, -1).contiguous();
+    B_t = B.transpose(-2, -1).contiguous();
+  }
+
+  uint64_t aRows = A_t.size(-2);
+  uint64_t aCols = A_t.size(-1);
+  uint64_t aElemSize = A_t.element_size();
+  int a_ndim = A_t.dim();
+  int b_ndim = B_t.dim();
+  int numberOfRightHandSides = (b_ndim == a_ndim - 1) ? 1 : (b_ndim >= 2 ? B_t.size(-1) : 1);
+
+  uint64_t numPivots = std::min(aRows, aCols);
+  std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
+  info.fill_(0); // will be set to 1 during kernel if something fails
+  resize_output(info, pivot_sizes);
+  pivot_sizes.push_back(numPivots);
+  resize_output(pivots, pivot_sizes);
+
+  if (A_t.numel() == 0) {
+    return;
+  }
+
+  if (A_t.dim() > 3) {
+    A_t = A_t.flatten(0, -3);
+  }
+
+  uint64_t batchSize = (A_t.dim() > 2) ? A_t.size(0) : 1;
+  std::vector<Tensor> status_tensors;
+  std::vector<Tensor> pivots_list;
+
+  status_tensors.reserve(batchSize);
+  pivots_list.reserve(batchSize);
+  for ([[maybe_unused]] const auto i : c10::irange(batchSize)) {
+    status_tensors.push_back(at::zeros(1, kInt, std::nullopt, kMPS, std::nullopt));
+    pivots_list.push_back(at::zeros(numPivots, kInt, std::nullopt, kMPS, std::nullopt));
+  }
+
+  resize_output(LU, A_t.sizes());
+  Tensor LU_ = LU;
+  if (!LU_.is_same(A_t)) {
+    A_t = LU_.copy_(A_t);
+  } else {
+    A_t = LU_;
+  }
+
+  TORCH_INTERNAL_ASSERT(A_t.is_contiguous());
+
+  Tensor result_t;
+  if (!left) {
+    // For right solve, we'll need to transpose the result back later
+    result_t = at::empty_like(B_t, B_t.options());
+  } else {
+    result_t = result;
+  }
+  id<MTLBuffer> luBuffer = getMTLBufferStorage(LU_);
+  id<MTLBuffer> bBuffer = getMTLBufferStorage(B_t);
+  id<MTLBuffer> resultBuffer = getMTLBufferStorage(result_t);
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+
+      MPSMatrixDecompositionLU* lu_decomp = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
+                                                                                         rows:aRows
+                                                                                      columns:aCols] autorelease];
+
+      MPSMatrixSolveLU* solver = [[[MPSMatrixSolveLU alloc] initWithDevice:device
+                                                                 transpose:false
+                                                                     order:aRows
+                                                    numberOfRightHandSides:numberOfRightHandSides] autorelease];
+
+      MPSMatrixDescriptor* luMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
+                                                                                columns:aCols
+                                                                               matrices:1
+                                                                               rowBytes:aCols * aElemSize
+                                                                            matrixBytes:aRows * aCols * aElemSize
+                                                                               dataType:getMPSDataType(LU_)];
+      MPSMatrixDescriptor* rhsMatrixDesc =
+          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
+                                                columns:numberOfRightHandSides
+                                               matrices:1
+                                               rowBytes:numberOfRightHandSides * aElemSize
+                                            matrixBytes:aRows * numberOfRightHandSides * aElemSize
+                                               dataType:getMPSDataType(B_t)];
+      MPSMatrixDescriptor* resultMatrixDesc =
+          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
+                                                columns:numberOfRightHandSides
+                                               matrices:1
+                                               rowBytes:numberOfRightHandSides * aElemSize
+                                            matrixBytes:aRows * numberOfRightHandSides * aElemSize
+                                               dataType:getMPSDataType(result_t)];
+      MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
+                                                                                    columns:numPivots
+                                                                                   matrices:1
+                                                                                   rowBytes:numPivots * sizeof(uint32_t)
+                                                                                matrixBytes:numPivots * sizeof(uint32_t)
+                                                                                   dataType:MPSDataTypeUInt32];
+
+      for (const auto i : c10::irange(batchSize)) {
+        const uint64_t batchOffsetA = i * aRows * aCols;
+        const uint64_t batchOffsetB = i * aRows * numberOfRightHandSides;
+        MPSMatrix* mpsLU = [[[MPSMatrix alloc] initWithBuffer:luBuffer
+                                                       offset:(LU_.storage_offset() + batchOffsetA) * aElemSize
+                                                   descriptor:luMatrixDesc] autorelease];
+
+        MPSMatrix* mpsRHS = [[[MPSMatrix alloc] initWithBuffer:bBuffer
+                                                        offset:(B_t.storage_offset() + batchOffsetB) * aElemSize
+                                                    descriptor:rhsMatrixDesc] autorelease];
+
+        MPSMatrix* mpsResult = [[[MPSMatrix alloc] initWithBuffer:resultBuffer
+                                                           offset:(result_t.storage_offset() + batchOffsetB) * aElemSize
+                                                       descriptor:resultMatrixDesc] autorelease];
+
+        MPSMatrix* mpsPivots = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
+                                                           offset:0
+                                                       descriptor:pivotsMatrixDesc] autorelease];
+        id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
+        [lu_decomp encodeToCommandBuffer:commandBuffer
+                            sourceMatrix:mpsLU
+                            resultMatrix:mpsLU
+                            pivotIndices:mpsPivots
+                                  status:statusBuffer];
+        [solver encodeToCommandBuffer:commandBuffer
+                         sourceMatrix:mpsLU
+                  rightHandSideMatrix:mpsRHS
+                         pivotIndices:mpsPivots
+                       solutionMatrix:mpsResult];
+      }
+    }
+  });
+
+  auto stacked_status = A.dim() > 2 ? at::stack(status_tensors) : status_tensors[0];
+  std::vector<int64_t> info_sizes(A.sizes().begin(), A.sizes().end() - 2);
+  info.copy_(stacked_status.view(info_sizes));
+
+  if (check_errors) {
+    for (const auto i : c10::irange(status_tensors.size())) {
+      int status = status_tensors[i].item<int>();
+      TORCH_CHECK(status == 0,
+                  "solve(): Linear solve failed at the ",
+                  i + 1,
+                  " sample with status: ",
+                  status,
+                  ". See https://developer.apple.com/documentation/metalperformanceshaders/"
+                  "mpsmatrixdecompositionstatus for details.");
+    }
+  }
+  if (!left) {
+    // If this was a right solve, transpose the result back
+    result.copy_(result_t.transpose(-2, -1).contiguous());
+  }
+}
+
 static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
   using namespace mps;
   static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
@@ -1119,6 +1295,18 @@ TORCH_IMPL_FUNC(triangular_solve_mps_out)
   result.copy_(out);
 }
 
+TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
+(const Tensor& A,
+ const Tensor& B,
+ bool left,
+ bool check_errors,
+ const Tensor& result,
+ const Tensor& LU,
+ const Tensor& pivots,
+ const Tensor& info) {
+  mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
+}
+
 std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
   Tensor info = at::empty({}, A.options().dtype(kInt));
   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 723fa1f9b7e..e412752a1dc 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14315,6 +14315,7 @@
   structured: True
   dispatch:
     CPU, CUDA: _linalg_solve_ex_out
+    MPS: _linalg_solve_ex_out_mps
 
 - func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)
   python_module: linalg
diff --git a/test/test_mps.py b/test/test_mps.py
index 12999221b1d..b4aa32d9e05 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -91,6 +91,9 @@ def mps_ops_grad_modifier(ops):
         'lu': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
         'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
         'linalg.lu_factor_ex': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
+        'linalg.solve': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
+        'linalg.solve_ex': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
+        'linalg.tensorsolve': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
         'linalg.det': [torch.float16, torch.float32],  # missing aten::lu_solve.out
         'aminmax': [torch.float32, torch.float16],
         'special.i1': [torch.float16],  # "i1_backward" not implemented for 'Half'
@@ -715,10 +718,7 @@ def mps_ops_modifier(ops):
         'linalg.normsubgradients_at_zero': [torch.float32],
         'linalg.qr': None,
         'linalg.slogdet': None,
-        'linalg.solve': None,
-        'linalg.solve_ex': None,
         'linalg.svdvals': None,
-        'linalg.tensorsolve': None,
         'linalg.vecdot': None,
         'logcumsumexp': None,
         'logdet': None,
@@ -2750,6 +2750,82 @@ class TestMPS(TestCaseMPS):
         run_lu_factor_ex_test(32, 10, 10, check_errors=False)
         run_lu_factor_ex_test(32, 2, 2, 2, 2, 10, 10, check_errors=True)
 
+    def test_linalg_solve(self):
+        from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
+
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)
+
+        def run_linalg_solve_test(size, *batch_dims):
+            A_cpu = make_arg(*batch_dims, size, size)
+            A_mps = A_cpu.to('mps')
+
+            for left in [True, False]:
+                if left:
+                    b_cpu = torch.randn(*batch_dims, size, 3, device='cpu', dtype=torch.float32)
+                else:
+                    b_cpu = torch.randn(*batch_dims, 3, size, device='cpu', dtype=torch.float32)
+
+                b_mps = b_cpu.to('mps')
+
+                # Solve the system
+                X_cpu = torch.linalg.solve(A_cpu, b_cpu, left=left)
+                X_mps = torch.linalg.solve(A_mps, b_mps, left=left)
+                self.assertEqual(X_cpu, X_mps)
+
+                # Test with transposed matrices
+                X_cpu_t = torch.linalg.solve(A_cpu.mT, b_cpu, left=left)
+                X_mps_t = torch.linalg.solve(A_mps.mT, b_mps, left=left)
+                self.assertEqual(X_cpu_t, X_mps_t)
+
+        # test with different even/odd matrix sizes
+        matrix_sizes = [1, 2, 3, 4]
+        # even/odd batch sizes
+        batch_sizes = [1, 2, 4]
+
+        for size in matrix_sizes:
+            for batch_size in batch_sizes:
+                run_linalg_solve_test(size, batch_size)
+
+        # test >3D matrices
+        run_linalg_solve_test(32, 10, 10)
+        run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
+
+    def test_linalg_solve_with_broadcasting(self):
+        from functools import partial
+        import torch
+        from torch.testing._internal.common_utils import (
+            make_fullrank_matrices_with_distinct_singular_values
+        )
+
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)
+
+        batch_size = 4
+        size = 3
+
+        A_cpu = make_arg(batch_size, size, size)
+        A_mps = A_cpu.to('mps')
+
+        for left in [True, False]:
+            b_cpu = torch.randn(batch_size, size, device='cpu', dtype=torch.float32)
+            b_mps = b_cpu.to('mps')
+
+            if left:
+                b_cpu = b_cpu.unsqueeze(-1)
+                b_mps = b_mps.unsqueeze(-1)
+            else:
+                b_cpu = b_cpu.view(batch_size, 1, size)
+                b_mps = b_mps.view(batch_size, 1, size)
+
+            X_cpu = torch.linalg.solve(A_cpu, b_cpu, left=left)
+            X_mps = torch.linalg.solve(A_mps, b_mps, left=left)
+            self.assertEqual(X_cpu, X_mps)
+
+            X_cpu_t = torch.linalg.solve(A_cpu.mT, b_cpu, left=left)
+            X_mps_t = torch.linalg.solve(A_mps.mT, b_mps, left=left)
+            self.assertEqual(X_cpu_t, X_mps_t)
+
     def test_linalg_det(self):
         from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
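
A minimal usage sketch, not part of the patch, of what this change enables: torch.linalg.solve running directly on the MPS device via the new kernel instead of requiring a CPU fallback. The tensor shapes and tolerance below are arbitrary illustrative choices.

    import torch

    # Guarded so the snippet is a no-op on machines without an MPS device.
    if torch.backends.mps.is_available():
        A = torch.randn(4, 3, 3, device="mps")  # batch of 4 square systems (illustrative sizes)
        B = torch.randn(4, 3, 2, device="mps")  # 2 right-hand sides per system
        X = torch.linalg.solve(A, B)            # dispatches to the MPS implementation added here
        print(torch.allclose(A @ X, B, atol=1e-4))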