mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS] linalg solve implementation (#146531)
Fixes #98222
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146531
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
parent
495049860b
commit
0dc03134d9
4 changed files with 272 additions and 6 deletions
|
|
@ -590,15 +590,16 @@ TORCH_META_FUNC(_linalg_solve_ex)(const Tensor& A,
|
||||||
TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
|
TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
|
||||||
auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
|
auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
|
||||||
: B_broad_shape;
|
: B_broad_shape;
|
||||||
auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/left);
|
// row major for mps implementation
|
||||||
|
auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/A.device().type() != at::kMPS? left : false);
|
||||||
|
|
||||||
set_output_strided(0, result_shape, result_strides, B.options(), {});
|
set_output_strided(0, result_shape, result_strides, B.options(), {});
|
||||||
|
|
||||||
auto shape = A.sizes();
|
auto shape = A.sizes();
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
|
|
||||||
// LU
|
// LU, row major for mps
|
||||||
auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/true);
|
auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig*=*/A.device().type() != at::kMPS? true : false);
|
||||||
set_output_strided(1, shape, LU_strides, A.options(), {});
|
set_output_strided(1, shape, LU_strides, A.options(), {});
|
||||||
|
|
||||||
// pivots
|
// pivots
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
#include <ATen/Functions.h>
|
#include <ATen/Functions.h>
|
||||||
#include <ATen/NativeFunctions.h>
|
#include <ATen/NativeFunctions.h>
|
||||||
#else
|
#else
|
||||||
|
#include <ATen/ops/_linalg_solve_ex_native.h>
|
||||||
#include <ATen/ops/addbmm_native.h>
|
#include <ATen/ops/addbmm_native.h>
|
||||||
#include <ATen/ops/addmm_native.h>
|
#include <ATen/ops/addmm_native.h>
|
||||||
#include <ATen/ops/addr_native.h>
|
#include <ATen/ops/addr_native.h>
|
||||||
|
|
@ -241,6 +242,181 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Solves a (possibly batched) linear system on MPS.
//
// For left == true the system is A X = B; for left == false it is X A = B.
// Each matrix in the batch is LU-factorised with MPSMatrixDecompositionLU and
// then solved with MPSMatrixSolveLU, both encoded on the current MPS stream.
//
// Parameters:
//   A            coefficient matrix (or batch of matrices); complex dtypes are rejected.
//   B            right-hand side(s).
//   left         selects A X = B (true) or X A = B (false).
//   check_errors when true, raise if any per-matrix LU status is non-zero.
//   result       output: receives the solution.
//   LU           output: used as the in-place working buffer for the factorisation.
//   pivots       output: resized here. NOTE(review): it is not written anywhere in
//                this function (the pivots live in temporary tensors) - confirm
//                that callers do not rely on its contents.
//   info         output: receives the per-matrix decomposition status.
static void linalg_solve_out_mps_impl(const at::Tensor& A,
                                      const at::Tensor& B,
                                      bool left,
                                      bool check_errors,
                                      const at::Tensor& result,
                                      const at::Tensor& LU,
                                      const at::Tensor& pivots,
                                      const at::Tensor& info) {
  using namespace mps;

  TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
              "linalg.solve(): MPS doesn't support complex types.");
  Tensor A_t, B_t;
  // If 'left' is false the system is X A = B. Transposing both sides gives
  // A^T X^T = B^T, so we solve the ordinary "left" problem on the transposed
  // inputs and transpose the solution back at the end.
  if (left) {
    A_t = A.contiguous();
    B_t = B.contiguous();
  } else {
    A_t = A.transpose(-2, -1).contiguous();
    B_t = B.transpose(-2, -1).contiguous();
  }

  uint64_t aRows = A_t.size(-2);
  uint64_t aCols = A_t.size(-1);
  uint64_t aElemSize = A_t.element_size();
  int a_ndim = A_t.dim();
  int b_ndim = B_t.dim();
  // B with one dimension fewer than A is treated as a single right-hand-side vector.
  int numberOfRightHandSides = (b_ndim == a_ndim - 1) ? 1 : (b_ndim >= 2 ? B_t.size(-1) : 1);

  uint64_t numPivots = std::min(aRows, aCols);
  std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
  // Resize before zero-filling so the fill covers the final storage; this also
  // guarantees a zeroed `info` on the empty-input early return below. The
  // per-matrix status is copied into `info` further down.
  resize_output(info, pivot_sizes);
  info.fill_(0);
  pivot_sizes.push_back(numPivots);
  resize_output(pivots, pivot_sizes);

  if (A_t.numel() == 0) {
    return;
  }

  // Collapse all leading batch dimensions into one so the batch can be walked
  // with a single index.
  if (A_t.dim() > 3) {
    A_t = A_t.flatten(0, -3);
  }

  uint64_t batchSize = (A_t.dim() > 2) ? A_t.size(0) : 1;
  std::vector<Tensor> status_tensors;
  std::vector<Tensor> pivots_list;

  // One status word and one pivot vector per matrix in the batch.
  status_tensors.reserve(batchSize);
  pivots_list.reserve(batchSize);
  for ([[maybe_unused]] const auto i : c10::irange(batchSize)) {
    status_tensors.push_back(at::zeros(1, kInt, std::nullopt, kMPS, std::nullopt));
    pivots_list.push_back(at::zeros(numPivots, kInt, std::nullopt, kMPS, std::nullopt));
  }

  // NOTE(review): A_t may already be flattened here, so LU is resized to the
  // flattened (batch, rows, cols) shape rather than A's original shape - confirm
  // this is intended.
  resize_output(LU, A_t.sizes());
  Tensor LU_ = LU;
  // The decomposition runs in place, so LU doubles as the working copy of A.
  if (!LU_.is_same(A_t)) {
    A_t = LU_.copy_(A_t);
  } else {
    A_t = LU_;
  }

  TORCH_INTERNAL_ASSERT(A_t.is_contiguous());

  Tensor result_t;
  if (!left) {
    // For right solve, we'll need to transpose the result back later
    result_t = at::empty_like(B_t, B_t.options());
  } else {
    result_t = result;
  }
  id<MTLBuffer> luBuffer = getMTLBufferStorage(LU_);
  id<MTLBuffer> bBuffer = getMTLBufferStorage(B_t);
  id<MTLBuffer> resultBuffer = getMTLBufferStorage(result_t);

  MPSStream* mpsStream = getCurrentMPSStream();
  id<MTLDevice> device = MPSDevice::getInstance()->device();

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();

      MPSMatrixDecompositionLU* lu_decomp = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
                                                                                         rows:aRows
                                                                                      columns:aCols] autorelease];

      MPSMatrixSolveLU* solver = [[[MPSMatrixSolveLU alloc] initWithDevice:device
                                                                 transpose:false
                                                                     order:aRows
                                                    numberOfRightHandSides:numberOfRightHandSides] autorelease];

      // All descriptors describe a single row-major matrix: rowBytes is the
      // column count times the element size.
      MPSMatrixDescriptor* luMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                                                columns:aCols
                                                                               matrices:1
                                                                               rowBytes:aCols * aElemSize
                                                                            matrixBytes:aRows * aCols * aElemSize
                                                                               dataType:getMPSDataType(LU_)];
      MPSMatrixDescriptor* rhsMatrixDesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                columns:numberOfRightHandSides
                                               matrices:1
                                               rowBytes:numberOfRightHandSides * aElemSize
                                            matrixBytes:aRows * numberOfRightHandSides * aElemSize
                                               dataType:getMPSDataType(B_t)];
      MPSMatrixDescriptor* resultMatrixDesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                columns:numberOfRightHandSides
                                               matrices:1
                                               rowBytes:numberOfRightHandSides * aElemSize
                                            matrixBytes:aRows * numberOfRightHandSides * aElemSize
                                               dataType:getMPSDataType(result_t)];
      MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
                                                                                    columns:numPivots
                                                                                   matrices:1
                                                                                   rowBytes:numPivots * sizeof(uint32_t)
                                                                                matrixBytes:numPivots * sizeof(uint32_t)
                                                                                   dataType:MPSDataTypeUInt32];

      for (const auto i : c10::irange(batchSize)) {
        // Element offsets of the i-th matrix inside the contiguous buffers.
        // NOTE(review): the B/result offset assumes B_t holds one right-hand
        // side per matrix of A_t (no batch broadcasting) - confirm against callers.
        const uint64_t batchOffsetA = i * aRows * aCols;
        const uint64_t batchOffsetB = i * aRows * numberOfRightHandSides;
        MPSMatrix* mpsLU = [[[MPSMatrix alloc] initWithBuffer:luBuffer
                                                       offset:(LU_.storage_offset() + batchOffsetA) * aElemSize
                                                   descriptor:luMatrixDesc] autorelease];

        MPSMatrix* mpsRHS = [[[MPSMatrix alloc] initWithBuffer:bBuffer
                                                        offset:(B_t.storage_offset() + batchOffsetB) * aElemSize
                                                    descriptor:rhsMatrixDesc] autorelease];

        MPSMatrix* mpsResult = [[[MPSMatrix alloc] initWithBuffer:resultBuffer
                                                           offset:(result_t.storage_offset() + batchOffsetB) * aElemSize
                                                       descriptor:resultMatrixDesc] autorelease];

        MPSMatrix* mpsPivots = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
                                                           offset:0
                                                       descriptor:pivotsMatrixDesc] autorelease];
        id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
        // Factorise in place (source and result are the same matrix), then
        // solve using the factors and pivots just produced.
        [lu_decomp encodeToCommandBuffer:commandBuffer
                            sourceMatrix:mpsLU
                            resultMatrix:mpsLU
                            pivotIndices:mpsPivots
                                  status:statusBuffer];
        [solver encodeToCommandBuffer:commandBuffer
                         sourceMatrix:mpsLU
                  rightHandSideMatrix:mpsRHS
                         pivotIndices:mpsPivots
                       solutionMatrix:mpsResult];
      }
    }
  });

  // Publish the per-matrix status through `info`, shaped like A's batch dimensions.
  auto stacked_status = A.dim() > 2 ? at::stack(status_tensors) : status_tensors[0];
  std::vector<int64_t> info_sizes(A.sizes().begin(), A.sizes().end() - 2);
  info.copy_(stacked_status.view(info_sizes));

  if (check_errors) {
    for (const auto i : c10::irange(status_tensors.size())) {
      int status = status_tensors[i].item<int>();
      TORCH_CHECK(status == 0,
                  "solve(): Linear solve failed at the ",
                  i + 1,
                  " sample with status: ",
                  status,
                  ". See https://developer.apple.com/documentation/metalperformanceshaders/"
                  "mpsmatrixdecompositionstatus for details.");
    }
  }
  if (!left) {
    // If this was a right solve, transpose the result back
    result.copy_(result_t.transpose(-2, -1).contiguous());
  }
}
|
||||||
|
|
||||||
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
||||||
using namespace mps;
|
using namespace mps;
|
||||||
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||||
|
|
@ -1119,6 +1295,18 @@ TORCH_IMPL_FUNC(triangular_solve_mps_out)
|
||||||
result.copy_(out);
|
result.copy_(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MPS entry point for _linalg_solve_ex: forwards every argument unchanged to
// mps::linalg_solve_out_mps_impl, which writes result, LU, pivots and info.
TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
|
||||||
|
(const Tensor& A,
|
||||||
|
const Tensor& B,
|
||||||
|
bool left,
|
||||||
|
bool check_errors,
|
||||||
|
const Tensor& result,
|
||||||
|
const Tensor& LU,
|
||||||
|
const Tensor& pivots,
|
||||||
|
const Tensor& info) {
|
||||||
|
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
|
||||||
|
}
|
||||||
|
|
||||||
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
|
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
|
||||||
Tensor info = at::empty({}, A.options().dtype(kInt));
|
Tensor info = at::empty({}, A.options().dtype(kInt));
|
||||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
|
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
|
||||||
|
|
|
||||||
|
|
@ -14315,6 +14315,7 @@
|
||||||
structured: True
|
structured: True
|
||||||
dispatch:
|
dispatch:
|
||||||
CPU, CUDA: _linalg_solve_ex_out
|
CPU, CUDA: _linalg_solve_ex_out
|
||||||
|
MPS: _linalg_solve_ex_out_mps
|
||||||
|
|
||||||
- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)
|
- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)
|
||||||
python_module: linalg
|
python_module: linalg
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,9 @@ def mps_ops_grad_modifier(ops):
|
||||||
'lu': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
'lu': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
||||||
'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
||||||
'linalg.lu_factor_ex': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
'linalg.lu_factor_ex': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
||||||
|
'linalg.solve': [torch.float16, torch.float32], # missing `aten::lu_solve`.
|
||||||
|
'linalg.solve_ex': [torch.float16, torch.float32], # missing `aten::lu_solve`.
|
||||||
|
'linalg.tensorsolve': [torch.float16, torch.float32], # missing `aten::lu_solve`.
|
||||||
'linalg.det': [torch.float16, torch.float32], # missing aten::lu_solve.out
|
'linalg.det': [torch.float16, torch.float32], # missing aten::lu_solve.out
|
||||||
'aminmax': [torch.float32, torch.float16],
|
'aminmax': [torch.float32, torch.float16],
|
||||||
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'
|
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'
|
||||||
|
|
@ -715,10 +718,7 @@ def mps_ops_modifier(ops):
|
||||||
'linalg.normsubgradients_at_zero': [torch.float32],
|
'linalg.normsubgradients_at_zero': [torch.float32],
|
||||||
'linalg.qr': None,
|
'linalg.qr': None,
|
||||||
'linalg.slogdet': None,
|
'linalg.slogdet': None,
|
||||||
'linalg.solve': None,
|
|
||||||
'linalg.solve_ex': None,
|
|
||||||
'linalg.svdvals': None,
|
'linalg.svdvals': None,
|
||||||
'linalg.tensorsolve': None,
|
|
||||||
'linalg.vecdot': None,
|
'linalg.vecdot': None,
|
||||||
'logcumsumexp': None,
|
'logcumsumexp': None,
|
||||||
'logdet': None,
|
'logdet': None,
|
||||||
|
|
@ -2750,6 +2750,82 @@ class TestMPS(TestCaseMPS):
|
||||||
run_lu_factor_ex_test(32, 10, 10, check_errors=False)
|
run_lu_factor_ex_test(32, 10, 10, check_errors=False)
|
||||||
run_lu_factor_ex_test(32, 2, 2, 2, 2, 10, 10, check_errors=True)
|
run_lu_factor_ex_test(32, 2, 2, 2, 2, 10, 10, check_errors=True)
|
||||||
|
|
||||||
|
def test_linalg_solve(self):
    """Check that torch.linalg.solve on MPS matches the CPU result.

    Covers left and right solves, the coefficient matrix and its transpose,
    several matrix and batch sizes, and inputs with more than one batch dimension.
    """
    from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values

    build_matrix = partial(
        make_fullrank_matrices_with_distinct_singular_values, device="cpu", dtype=torch.float32
    )
    num_rhs = 3

    def run_linalg_solve_test(size, *batch_dims):
        A_cpu = build_matrix(*batch_dims, size, size)
        A_mps = A_cpu.to('mps')

        for left in (True, False):
            # A left solve takes B as (..., size, k); a right solve takes (..., k, size).
            rhs_shape = (size, num_rhs) if left else (num_rhs, size)
            b_cpu = torch.randn(*batch_dims, *rhs_shape, device='cpu', dtype=torch.float32)
            b_mps = b_cpu.to('mps')

            # First the matrix as created, then its transposed view.
            for lhs_cpu, lhs_mps in ((A_cpu, A_mps), (A_cpu.mT, A_mps.mT)):
                expected = torch.linalg.solve(lhs_cpu, b_cpu, left=left)
                actual = torch.linalg.solve(lhs_mps, b_mps, left=left)
                self.assertEqual(expected, actual)

    # even/odd matrix sizes crossed with even/odd batch sizes
    cases = [(size, batch_size) for size in (1, 2, 3, 4) for batch_size in (1, 2, 4)]
    for size, batch_size in cases:
        run_linalg_solve_test(size, batch_size)

    # inputs with more than one batch dimension
    run_linalg_solve_test(32, 10, 10)
    run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
|
||||||
|
|
||||||
|
def test_linalg_solve_with_broadcasting(self):
    """Check torch.linalg.solve on MPS against CPU for a single right-hand side.

    The right-hand side is a batch of vectors presented as a one-column matrix
    (left solve) or a one-row matrix (right solve).
    NOTE(review): B carries the same batch size as A here, so batch-dimension
    broadcasting itself is not exercised - confirm whether that is intended.
    """
    # `partial` and `torch` come from module scope, as in test_linalg_solve.
    from torch.testing._internal.common_utils import (
        make_fullrank_matrices_with_distinct_singular_values
    )

    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)

    batch_size = 4
    size = 3

    A_cpu = make_arg(batch_size, size, size)
    A_mps = A_cpu.to('mps')

    for left in [True, False]:
        b_cpu = torch.randn(batch_size, size, device='cpu', dtype=torch.float32)
        b_mps = b_cpu.to('mps')

        if left:
            # column vector per batch entry: (batch, size, 1)
            b_cpu = b_cpu.unsqueeze(-1)
            b_mps = b_mps.unsqueeze(-1)
        else:
            # row vector per batch entry: (batch, 1, size)
            b_cpu = b_cpu.view(batch_size, 1, size)
            b_mps = b_mps.view(batch_size, 1, size)

        # Compare with the matrix as created, then with its transposed view.
        for lhs_cpu, lhs_mps in ((A_cpu, A_mps), (A_cpu.mT, A_mps.mT)):
            X_cpu = torch.linalg.solve(lhs_cpu, b_cpu, left=left)
            X_mps = torch.linalg.solve(lhs_mps, b_mps, left=left)
            self.assertEqual(X_cpu, X_mps)
|
||||||
|
|
||||||
def test_linalg_det(self):
|
def test_linalg_det(self):
|
||||||
from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
|
from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue