[MPS] linalg solve implementation (#146531)

Fixes #98222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Isalia20 2025-02-06 00:57:49 +00:00 committed by PyTorch MergeBot
parent 495049860b
commit 0dc03134d9
4 changed files with 272 additions and 6 deletions

View file

@ -590,15 +590,16 @@ TORCH_META_FUNC(_linalg_solve_ex)(const Tensor& A,
TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
: B_broad_shape;
auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/left);
// row major for mps implementation
auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/A.device().type() != at::kMPS? left : false);
set_output_strided(0, result_shape, result_strides, B.options(), {});
auto shape = A.sizes();
auto ndim = shape.size();
// LU
auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true);
// LU, row major for mps
auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/A.device().type() != at::kMPS? true : false);
set_output_strided(1, shape, LU_strides, A.options(), {});
// pivots

View file

@ -13,6 +13,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_linalg_solve_ex_native.h>
#include <ATen/ops/addbmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addr_native.h>
@ -241,6 +242,181 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
}
}
// Solves the linear system A·X = B (left = true) or X·A = B (left = false) on
// the MPS backend. Each batch matrix is LU-factorized in place with
// MPSMatrixDecompositionLU, then solved with MPSMatrixSolveLU using the
// recorded pivots. Results are written into the structured-kernel outputs
// `result`, `LU`, `pivots` and `info`. The LU factors are produced in
// row-major layout, which is why the meta function requests row-major
// (f_contig = false) strides for MPS.
static void linalg_solve_out_mps_impl(const at::Tensor& A,
                                      const at::Tensor& B,
                                      bool left,
                                      bool check_errors,
                                      const at::Tensor& result,
                                      const at::Tensor& LU,
                                      const at::Tensor& pivots,
                                      const at::Tensor& info) {
  using namespace mps;
  // Fix: this is the solve entry point, not lu_factor — report the correct op
  // name in the error message.
  TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
              "linalg.solve(): MPS doesn't support complex types.");
  Tensor A_t, B_t;
  // If 'left' is false, reinterpret the problem so that Ax = B becomes A^T ⋅ (x^T) = B^T
  // Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output
  if (left) {
    A_t = A.contiguous();
    B_t = B.contiguous();
  } else {
    A_t = A.transpose(-2, -1).contiguous();
    B_t = B.transpose(-2, -1).contiguous();
  }
  uint64_t aRows = A_t.size(-2);
  uint64_t aCols = A_t.size(-1);
  uint64_t aElemSize = A_t.element_size();
  int a_ndim = A_t.dim();
  int b_ndim = B_t.dim();
  // A vector right-hand side (b_ndim == a_ndim - 1) is treated as one column.
  int numberOfRightHandSides = (b_ndim == a_ndim - 1) ? 1 : (b_ndim >= 2 ? B_t.size(-1) : 1);
  uint64_t numPivots = std::min(aRows, aCols);
  std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
  // NOTE(review): info is filled before it is resized; the value is
  // overwritten from the per-batch status tensors below in any case.
  info.fill_(0); // will be set to 1 during kernel if something fails
  resize_output(info, pivot_sizes);
  pivot_sizes.push_back(numPivots);
  resize_output(pivots, pivot_sizes);
  // Nothing to solve for empty inputs; outputs are already correctly sized.
  if (A_t.numel() == 0) {
    return;
  }
  // Collapse all leading batch dimensions into a single batch axis so the
  // per-batch loop below only needs one offset computation.
  if (A_t.dim() > 3) {
    A_t = A_t.flatten(0, -3);
  }
  uint64_t batchSize = (A_t.dim() > 2) ? A_t.size(0) : 1;
  // Per-batch scratch: one MPS status word and one pivot vector per matrix.
  std::vector<Tensor> status_tensors;
  std::vector<Tensor> pivots_list;
  status_tensors.reserve(batchSize);
  pivots_list.reserve(batchSize);
  for ([[maybe_unused]] const auto i : c10::irange(batchSize)) {
    status_tensors.push_back(at::zeros(1, kInt, std::nullopt, kMPS, std::nullopt));
    pivots_list.push_back(at::zeros(numPivots, kInt, std::nullopt, kMPS, std::nullopt));
  }
  // The factorization happens in place on LU, so seed it with A's data.
  resize_output(LU, A_t.sizes());
  Tensor LU_ = LU;
  if (!LU_.is_same(A_t)) {
    A_t = LU_.copy_(A_t);
  } else {
    A_t = LU_;
  }
  TORCH_INTERNAL_ASSERT(A_t.is_contiguous());
  Tensor result_t;
  if (!left) {
    // For right solve, we'll need to transpose the result back later
    result_t = at::empty_like(B_t, B_t.options());
  } else {
    result_t = result;
  }
  id<MTLBuffer> luBuffer = getMTLBufferStorage(LU_);
  id<MTLBuffer> bBuffer = getMTLBufferStorage(B_t);
  id<MTLBuffer> resultBuffer = getMTLBufferStorage(result_t);
  MPSStream* mpsStream = getCurrentMPSStream();
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
      MPSMatrixDecompositionLU* lu_decomp = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
                                                                                         rows:aRows
                                                                                      columns:aCols] autorelease];
      MPSMatrixSolveLU* solver = [[[MPSMatrixSolveLU alloc] initWithDevice:device
                                                                 transpose:false
                                                                     order:aRows
                                                    numberOfRightHandSides:numberOfRightHandSides] autorelease];
      // One shared descriptor per role; per-batch offsets are applied when the
      // MPSMatrix wrappers are created in the loop below.
      MPSMatrixDescriptor* luMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                                                columns:aCols
                                                                               matrices:1
                                                                               rowBytes:aCols * aElemSize
                                                                            matrixBytes:aRows * aCols * aElemSize
                                                                               dataType:getMPSDataType(LU_)];
      MPSMatrixDescriptor* rhsMatrixDesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                columns:numberOfRightHandSides
                                               matrices:1
                                               rowBytes:numberOfRightHandSides * aElemSize
                                            matrixBytes:aRows * numberOfRightHandSides * aElemSize
                                               dataType:getMPSDataType(B_t)];
      MPSMatrixDescriptor* resultMatrixDesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                columns:numberOfRightHandSides
                                               matrices:1
                                               rowBytes:numberOfRightHandSides * aElemSize
                                            matrixBytes:aRows * numberOfRightHandSides * aElemSize
                                               dataType:getMPSDataType(result_t)];
      MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
                                                                                    columns:numPivots
                                                                                   matrices:1
                                                                                   rowBytes:numPivots * sizeof(uint32_t)
                                                                                matrixBytes:numPivots * sizeof(uint32_t)
                                                                                   dataType:MPSDataTypeUInt32];
      for (const auto i : c10::irange(batchSize)) {
        // Element offsets into the (contiguous, row-major) batched storage.
        const uint64_t batchOffsetA = i * aRows * aCols;
        const uint64_t batchOffsetB = i * aRows * numberOfRightHandSides;
        MPSMatrix* mpsLU = [[[MPSMatrix alloc] initWithBuffer:luBuffer
                                                       offset:(LU_.storage_offset() + batchOffsetA) * aElemSize
                                                   descriptor:luMatrixDesc] autorelease];
        MPSMatrix* mpsRHS = [[[MPSMatrix alloc] initWithBuffer:bBuffer
                                                        offset:(B_t.storage_offset() + batchOffsetB) * aElemSize
                                                    descriptor:rhsMatrixDesc] autorelease];
        MPSMatrix* mpsResult = [[[MPSMatrix alloc] initWithBuffer:resultBuffer
                                                           offset:(result_t.storage_offset() + batchOffsetB) * aElemSize
                                                       descriptor:resultMatrixDesc] autorelease];
        MPSMatrix* mpsPivots = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
                                                           offset:0
                                                       descriptor:pivotsMatrixDesc] autorelease];
        id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
        // Factor A in place (sourceMatrix == resultMatrix), then back-substitute.
        [lu_decomp encodeToCommandBuffer:commandBuffer
                            sourceMatrix:mpsLU
                            resultMatrix:mpsLU
                            pivotIndices:mpsPivots
                                  status:statusBuffer];
        [solver encodeToCommandBuffer:commandBuffer
                         sourceMatrix:mpsLU
                  rightHandSideMatrix:mpsRHS
                         pivotIndices:mpsPivots
                       solutionMatrix:mpsResult];
      }
    }
  });
  // Gather the per-batch status words into `info`, shaped like A's batch dims.
  auto stacked_status = A.dim() > 2 ? at::stack(status_tensors) : status_tensors[0];
  std::vector<int64_t> info_sizes(A.sizes().begin(), A.sizes().end() - 2);
  info.copy_(stacked_status.view(info_sizes));
  if (check_errors) {
    for (const auto i : c10::irange(status_tensors.size())) {
      int status = status_tensors[i].item<int>();
      TORCH_CHECK(status == 0,
                  "solve(): Linear solve failed at the ",
                  i + 1,
                  " sample with status: ",
                  status,
                  ". See https://developer.apple.com/documentation/metalperformanceshaders/"
                  "mpsmatrixdecompositionstatus for details.");
    }
  }
  if (!left) {
    // If this was a right solve, transpose the result back
    result.copy_(result_t.transpose(-2, -1).contiguous());
  }
}
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
using namespace mps;
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
@ -1119,6 +1295,18 @@ TORCH_IMPL_FUNC(triangular_solve_mps_out)
result.copy_(out);
}
// Structured-kernel entry point for linalg.solve_ex on MPS; the signature is
// fixed by the TORCH_IMPL_FUNC machinery (outputs were allocated by the meta
// function). Simply forwards to the MPS implementation above.
TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
(const Tensor& A,
const Tensor& B,
bool left,
bool check_errors,
const Tensor& result,
const Tensor& LU,
const Tensor& pivots,
const Tensor& info) {
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
}
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);

View file

@ -14315,6 +14315,7 @@
structured: True
dispatch:
CPU, CUDA: _linalg_solve_ex_out
MPS: _linalg_solve_ex_out_mps
- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)
python_module: linalg

View file

@ -91,6 +91,9 @@ def mps_ops_grad_modifier(ops):
'lu': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'linalg.lu_factor_ex': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'linalg.solve': [torch.float16, torch.float32], # missing `aten::lu_solve`.
'linalg.solve_ex': [torch.float16, torch.float32], # missing `aten::lu_solve`.
'linalg.tensorsolve': [torch.float16, torch.float32], # missing `aten::lu_solve`.
'linalg.det': [torch.float16, torch.float32], # missing aten::lu_solve.out
'aminmax': [torch.float32, torch.float16],
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'
@ -715,10 +718,7 @@ def mps_ops_modifier(ops):
'linalg.normsubgradients_at_zero': [torch.float32],
'linalg.qr': None,
'linalg.slogdet': None,
'linalg.solve': None,
'linalg.solve_ex': None,
'linalg.svdvals': None,
'linalg.tensorsolve': None,
'linalg.vecdot': None,
'logcumsumexp': None,
'logdet': None,
@ -2750,6 +2750,82 @@ class TestMPS(TestCaseMPS):
run_lu_factor_ex_test(32, 10, 10, check_errors=False)
run_lu_factor_ex_test(32, 2, 2, 2, 2, 10, 10, check_errors=True)
def test_linalg_solve(self):
    """Compare torch.linalg.solve on MPS against the CPU reference, for both
    left=True (solve A X = B) and left=False (solve X A = B), on contiguous
    and transposed (mT) inputs, across several matrix and batch sizes."""
    from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)

    def run_linalg_solve_test(size, *batch_dims):
        # Full-rank A guarantees the system has a unique solution.
        A_cpu = make_arg(*batch_dims, size, size)
        A_mps = A_cpu.to('mps')
        for left in [True, False]:
            # For left=False the rhs multiplies from the left, so its
            # matrix dims are swapped relative to the left=True case.
            if left:
                b_cpu = torch.randn(*batch_dims, size, 3, device='cpu', dtype=torch.float32)
            else:
                b_cpu = torch.randn(*batch_dims, 3, size, device='cpu', dtype=torch.float32)
            b_mps = b_cpu.to('mps')
            # Solve the system
            X_cpu = torch.linalg.solve(A_cpu, b_cpu, left=left)
            X_mps = torch.linalg.solve(A_mps, b_mps, left=left)
            self.assertEqual(X_cpu, X_mps)
            # Test with transposed matrices
            X_cpu_t = torch.linalg.solve(A_cpu.mT, b_cpu, left=left)
            X_mps_t = torch.linalg.solve(A_mps.mT, b_mps, left=left)
            self.assertEqual(X_cpu_t, X_mps_t)

    # test with different even/odd matrix sizes
    matrix_sizes = [1, 2, 3, 4]
    # even/odd batch sizes
    batch_sizes = [1, 2, 4]
    for size in matrix_sizes:
        for batch_size in batch_sizes:
            run_linalg_solve_test(size, batch_size)
    # test >3D matrices
    run_linalg_solve_test(32, 10, 10)
    run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
def test_linalg_solve_with_broadcasting(self):
    """Exercise torch.linalg.solve on MPS with a vector right-hand side that
    must be broadcast/reshaped against the batched matrix A, comparing the
    result with the CPU implementation for both left=True and left=False."""
    from functools import partial
    import torch
    from torch.testing._internal.common_utils import (
        make_fullrank_matrices_with_distinct_singular_values
    )

    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)
    batch_size = 4
    size = 3
    A_cpu = make_arg(batch_size, size, size)
    A_mps = A_cpu.to('mps')
    for left in [True, False]:
        b_cpu = torch.randn(batch_size, size, device='cpu', dtype=torch.float32)
        b_mps = b_cpu.to('mps')
        # Shape b as a column vector for the left solve and as a row vector
        # for the right solve, matching linalg.solve's broadcasting rules.
        if left:
            b_cpu = b_cpu.unsqueeze(-1)
            b_mps = b_mps.unsqueeze(-1)
        else:
            b_cpu = b_cpu.view(batch_size, 1, size)
            b_mps = b_mps.view(batch_size, 1, size)
        X_cpu = torch.linalg.solve(A_cpu, b_cpu, left=left)
        X_mps = torch.linalg.solve(A_mps, b_mps, left=left)
        self.assertEqual(X_cpu, X_mps)
        # Also check with transposed coefficient matrices.
        X_cpu_t = torch.linalg.solve(A_cpu.mT, b_cpu, left=left)
        X_mps_t = torch.linalg.solve(A_mps.mT, b_mps, left=left)
        self.assertEqual(X_cpu_t, X_mps_t)
def test_linalg_det(self):
from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values