diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
index 85b82e3acd6..f5f1eb2fe9a 100644
--- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
+++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
@@ -1,4 +1,5 @@
 #include <metal_array>
+#include <metal_stdlib>
 using namespace metal;
 
 template <typename T>
@@ -31,6 +32,271 @@ kernel void naive_matmul(
   outputData[x * strides[2].x + y * strides[2].y] = rc;
 }
 
+// Tree reduction across the threadgroup; returns the sum of `val` over all
+// `tpg` threads (assumes tpg is a power of two).
+inline float blockReduceSum(
+    threadgroup float* sharedScratch,
+    float val,
+    uint tid,
+    uint tpg) {
+  sharedScratch[tid] = val;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      sharedScratch[tid] += sharedScratch[tid + offset];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  return sharedScratch[0];
+}
+
+// Unblocked Cholesky factorization of the k-th diagonal block; one
+// threadgroup per matrix in the batch.
+kernel void factorDiagonalBlock(
+    device float* A [[buffer(0)]],
+    device int* success [[buffer(1)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint tid [[thread_position_in_threadgroup]],
+    uint bid [[threadgroup_position_in_grid]],
+    uint tpg [[threads_per_threadgroup]]) {
+  const uint actSize = min(N - k * NB, NB); // the last block may be smaller than NB
+  const uint batch_offset = bid * N * N;
+
+  const uint row0 = k * NB;
+  const uint col0 = k * NB;
+
+  threadgroup float tile[32][33]; // padded to 33 columns to avoid bank conflicts
+  threadgroup float reduceScratch[256];
+  const uint tileSize = actSize * actSize;
+
+  for (uint i = tid; i < tileSize; i += tpg) {
+    uint r = i / actSize;
+    uint c = i % actSize;
+    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint kk = 0; kk < actSize; kk++) {
+    // Sum of squares of the factored entries to the left of the pivot
+    float diagElt = 0.0f;
+    if (kk > 0) {
+      float partialSum = 0.0f;
+      for (uint i = tid; i < kk; i += tpg) {
+        float val = tile[kk][i];
+        partialSum = fma(val, val, partialSum);
+      }
+      diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
+    }
+
+    if (tid == 0) {
+      float diagVal = tile[kk][kk] - diagElt;
+      // Check for positive definiteness
+      if (diagVal <= 0.0f) {
+        success[bid] = 0; // matrix is not positive definite
+        return;
+      }
+      tile[kk][kk] = sqrt(diagVal);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float pivot = tile[kk][kk];
+
+    // Update the column below the pivot
+    for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
+      float partialSum = 0.0f;
+      for (uint i = 0; i < kk; i++) {
+        partialSum = fma(tile[j][i], tile[kk][i], partialSum);
+      }
+
+      float val = tile[j][kk];
+      val -= partialSum;
+      val /= pivot;
+      tile[j][kk] = val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  for (uint i = tid; i < tileSize; i += tpg) {
+    uint r = i / actSize;
+    uint c = i % actSize;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
+  }
+}
+
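+// Triangular solve for one off-diagonal block: with the diagonal block
+// L[k][k] already factored, overwrite A[j][k] (j > k) with
+// A[j][k] * inv(L[k][k])^T via forward substitution over columns.
+// Grid: x = batch index, y = relative block row below the diagonal.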
+kernel void applyTRSM(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint b = tgid.x;
+  uint idxJ = tgid.y;
+
+  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
+  const uint batch_offset = b * N * N;
+  const uint j = (k + 1) + idxJ;
+
+  uint row0 = j * NB;
+  uint col0 = k * NB;
+
+  uint actSize_j = (uint)min((int)(N - row0), (int)NB);
+  if (actSize_k == 0 || actSize_j == 0) {
+    return;
+  }
+  if (j == k) {
+    return;
+  }
+
+  threadgroup float diag[32 * 32];
+  threadgroup float target[32 * 32];
+
+  for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
+  }
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint col = 0; col < actSize_k; col++) {
+    float diag_val = diag[col * actSize_k + col];
+    if (abs(diag_val) < 1e-6f) {
+      diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f; // clamp near-zero pivots
+    }
+
+    for (uint row = tid.x; row < actSize_j; row += tpg.x) {
+      float sum = target[row * actSize_k + col];
+
+      // Kahan-compensated accumulation of the substitution terms
+      float c = 0.0f;
+      for (uint p = 0; p < col; p++) {
+        float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
+        float t = sum + y;
+        c = (t - sum) - y;
+        sum = t;
+      }
+
+      target[row * actSize_k + col] = sum / diag_val;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
+  }
+}
+
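+// Trailing-submatrix update: for the block pair (j, h) with j >= h > k,
+// compute A[j][h] -= A[j][k] * A[h][k]^T. pairID linearizes the lower
+// triangle of block pairs as pairID = jRel*(jRel+1)/2 + hRel; the quadratic
+// formula below inverts that mapping (e.g. pairID 4 -> jRel 2, hRel 1).
+// Grid: x = batch index, y = pairID.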
+kernel void applySYRK(
+    device float* A [[buffer(0)]],
+    constant uint& N [[buffer(2)]],
+    constant uint& NB [[buffer(3)]],
+    constant uint& k [[buffer(4)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 tgid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint b = tgid.x;
+  uint pairID = tgid.y;
+
+  uint jRel = (-1 + sqrt(1 + 8 * float(pairID))) / 2;
+  uint hRel = pairID - (jRel * (jRel + 1) >> 1);
+
+  const uint startJ = (k + 1);
+  uint j = startJ + jRel;
+  uint h = startJ + hRel;
+  uint row0 = j * NB;
+  uint col0 = h * NB;
+
+  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
+  const uint actSize_j = min((uint)(N - row0), NB);
+  const uint actSize_h = min((uint)(N - col0), NB);
+  const uint batch_offset = b * N * N;
+
+  if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0)
+    return;
+
+  threadgroup float left[32 * 33];
+  threadgroup float right_t[32 * 33];
+  threadgroup float tile[32 * 33];
+
+  const uint threads = min(tpg.x, actSize_j * actSize_k);
+
+  for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
+  }
+
+  // Stage the right operand transposed so the inner loop walks it with a
+  // fixed stride of actSize_h
+  for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
+    uint r = i / actSize_k;
+    uint c = i % actSize_k;
+    right_t[c * actSize_h + r] =
+        A[batch_offset + (h * NB + r) * N + (k * NB + c)];
+  }
+
+  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
+    uint r = i / actSize_h;
+    uint c = i % actSize_h;
+    tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
+    uint r = idx / actSize_h;
+    uint c = idx % actSize_h;
+
+    // Diagonal block pairs only need the lower triangle
+    if ((j == h) && (r < c))
+      continue;
+
+    uint tile_idx = r * actSize_h + c;
+    float sum = tile[tile_idx];
+
+    uint left_row = r * actSize_k;
+    uint right_col = c;
+
+    uint kk = 0; // renamed from `k` to avoid shadowing the block index
+    float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
+
+    // Vectorized inner product, four terms at a time
+    for (; kk + 4 <= actSize_k; kk += 4) {
+      float4 left4 = {
+          left[left_row + kk],
+          left[left_row + kk + 1],
+          left[left_row + kk + 2],
+          left[left_row + kk + 3]};
+
+      float4 right4 = {
+          right_t[(kk + 0) * actSize_h + right_col],
+          right_t[(kk + 1) * actSize_h + right_col],
+          right_t[(kk + 2) * actSize_h + right_col],
+          right_t[(kk + 3) * actSize_h + right_col]};
+
+      sum4 = fma(left4, right4, sum4);
+    }
+
+    sum -= dot(sum4, float4(1.0f)); // horizontal sum of the four partial products
+
+    for (; kk < actSize_k; kk++) {
+      sum = fma(-left[left_row + kk], right_t[kk * actSize_h + right_col], sum);
+    }
+
+    tile[tile_idx] = sum;
+  }
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
+    uint r = i / actSize_h;
+    uint c = i % actSize_h;
+    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
+  }
+}
+
 #define INSTANTIATE_NAIVE_MM(DTYPE)                            \
   template [[host_name("naive_matmul_" #DTYPE)]] kernel void   \
   naive_matmul<DTYPE>(                                         \
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index fe77c1936a2..32e983238dc 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -18,6 +18,8 @@
 #include <ATen/ops/addr_native.h>
 #include <ATen/ops/baddbmm_native.h>
 #include <ATen/ops/bmm_native.h>
+#include <ATen/ops/cholesky_native.h>
+#include <ATen/ops/linalg_cholesky_native.h>
 #include <ATen/ops/linalg_lu_factor_ex_native.h>
 #include <ATen/ops/linalg_lu_factor_native.h>
 #include <ATen/ops/linalg_solve_triangular_native.h>
@@ -780,6 +782,83 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
   return out;
 }
 
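+// Blocked, right-looking Cholesky factorization (lower triangular), roughly:
+//
+//   for k in 0..numBlocks-1:
+//     factorDiagonalBlock: A[k][k] = chol(A[k][k])
+//     applyTRSM:           A[j][k] = A[j][k] * inv(L[k][k])^T,  j > k
+//     applySYRK:           A[j][h] -= A[j][k] * A[h][k]^T,      j >= h > k
+//
+// All kernels are encoded back-to-back on a single command encoder; each
+// threadgroup works on one NB x NB block of one matrix of the batch.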
+static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
+  using namespace mps;
+
+  TORCH_CHECK(out.is_mps());
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
+  TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
+  TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");
+
+  if (input.numel() == 0 || out.numel() == 0) {
+    out.zero_();
+    return out;
+  }
+  resize_output(out, input.sizes());
+  out.copy_(input);
+
+  int64_t ndim = out.dim();
+  int64_t N = out.size(-1);
+  int64_t B = 1;
+  for (int64_t i = 0; i < ndim - 2; i++) {
+    B *= out.size(i);
+  }
+
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+
+  auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
+  auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
+  auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");
+
+  int64_t NB = std::min<int64_t>(32, N);
+  int64_t numBlocks = (N + NB - 1) / NB;
+
+  Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
+  id<MTLBuffer> successBuffer = getMTLBufferStorage(success);
+
+  MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
+  id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
+  id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
+  [computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
+  [computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
+  [computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];
+
+  @autoreleasepool {
+    dispatch_sync_with_rethrow(stream->queue(), ^() {
+      for (int64_t k = 0; k < numBlocks; k++) {
+        [computeEncoder setComputePipelineState:factorDiagonalPSO];
+        [computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
+        [computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
+        MTLSize gridSize = MTLSizeMake(B, 1, 1);
+        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
+
+        // process all remaining blocks in this row/column in parallel
+        if (k < numBlocks - 1) {
+          int64_t startJ = k + 1;
+          int64_t nBlocksJ = (numBlocks - startJ);
+
+          if (nBlocksJ > 0) {
+            // TRSM for all blocks in parallel
+            MTLSize trsmGridSize = MTLSizeMake(B, nBlocksJ, 1);
+            [computeEncoder setComputePipelineState:applyTRSMPSO];
+            [computeEncoder dispatchThreadgroups:trsmGridSize threadsPerThreadgroup:threadGroupSize];
+
+            // SYRK for all independent block pairs in parallel
+            uint32_t nPairs = nBlocksJ * (nBlocksJ + 1) / 2;
+            MTLSize syrkGridSize = MTLSizeMake(B, nPairs, 1);
+            [computeEncoder setComputePipelineState:applySYRKPSO];
+            [computeEncoder dispatchThreadgroups:syrkGridSize threadsPerThreadgroup:threadGroupSize];
+          }
+        }
+      }
+    });
+  }
+
+  TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
+  out.tril_(); // zero out the stale values above the diagonal
+  return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
+}
+
 } // namespace mps
 
 Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -940,6 +1019,25 @@ Tensor& addbmm_out_mps(const Tensor& self,
   return result;
 }
 
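+// The wrappers below all funnel into linalg_cholesky_mps_impl; their MPS
+// dispatch keys are registered in native_functions.yaml.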
+Tensor cholesky_mps(const Tensor& self, bool upper) {
+  auto out = at::empty_like(self, MemoryFormat::Contiguous);
+  mps::linalg_cholesky_mps_impl(self, upper, out);
+  return out;
+}
+
+Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
+  return mps::linalg_cholesky_mps_impl(self, upper, out);
+}
+
+Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
+  return mps::linalg_cholesky_mps_impl(self, upper, out);
+}
+
+Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
+  auto out = at::empty_like(self, MemoryFormat::Contiguous);
+  return mps::linalg_cholesky_mps_impl(self, upper, out);
+}
+
 Tensor addbmm_mps(const Tensor& self,
                   const Tensor& batch1,
                   const Tensor& batch2,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4ce561bb2a9..2eb5e094188 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9439,11 +9439,13 @@
 - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: cholesky_out
+    MPS: cholesky_mps_out
 
 - func: cholesky(Tensor self, bool upper=False) -> Tensor
   variants: method, function
   dispatch:
     CPU, CUDA: cholesky
+    MPS: cholesky_mps
 
 - func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -13900,9 +13902,15 @@
 - func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
   python_module: linalg
+  dispatch:
+    CompositeImplicitAutograd: linalg_cholesky
+    MPS: linalg_cholesky_mps
 
 - func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
+  dispatch:
+    CompositeImplicitAutograd: linalg_cholesky_out
+    MPS: linalg_cholesky_out_mps
 
 - func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
   python_module: linalg
diff --git a/test/test_mps.py b/test/test_mps.py
index 2eb4eedf80b..712c939135b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -673,7 +673,6 @@ def mps_ops_modifier(ops):
         '__rsub__': None,
         'cauchy_': None,
         'cauchy': None,
-        'cholesky': None,
         'cholesky_inverse': None,
         'cholesky_solve': None,
         'cummax': None,
@@ -693,7 +692,6 @@ def mps_ops_modifier(ops):
         'index_reduceamin': None,
         'kthvalue': None,
         'lcm': None,
-        'linalg.cholesky': None,
         'linalg.cholesky_ex': None,
         'linalg.cond': None,
         'linalg.det': None,
@@ -6388,6 +6386,30 @@ class TestMPS(TestCaseMPS):
             atol=0, rtol=0
         )
 
+    def test_cholesky(self):
+        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+        def run_cholesky_test(size, *batch_dims, upper):
+            input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
+            input_mps = input_cpu.to('mps')
+            output_cpu = torch.linalg.cholesky(input_cpu, upper=upper)
+            output_mps = torch.linalg.cholesky(input_mps, upper=upper)
+            self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)
+
+        # test with different even/odd matrix sizes
+        matrix_sizes = [1, 2, 3, 4, 8, 17, 64, 128, 154]
+        # even/odd batch sizes
+        batch_sizes = [1, 2, 4, 8, 16, 17]
+
+        for upper in [True, False]:
+            for size in matrix_sizes:
+                for batch_size in batch_sizes:
+                    run_cholesky_test(size, batch_size, upper=upper)
+
+        # test inputs with more than 3 dimensions
+        run_cholesky_test(128, 10, 10, upper=False)
+        run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)
+
     def test_upsample_nearest2d(self):
         def helper(N, C, H, W, memory_format):
             inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index fa77b906b1b..595f2757d5c 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -410,6 +410,11 @@
   self: cholesky_backward(grad, upper, L)
   L: cholesky_jvp(self_t, L, upper)
 
+# Temporary: remove once linalg_cholesky dispatches to linalg_cholesky_ex on MPS
+- name: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
+  self: cholesky_backward(grad, upper, result)
+  result: cholesky_jvp(self_t, result, upper)
+
 - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
   self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask)
   result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)