From 2328dcccb9e8978f3e7b6d2fc0a7482fc37c5ed8 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 7 Feb 2025 20:27:33 -0800 Subject: [PATCH] [MPSInductor] Implement Welford reduction (#146703) Still work in progress, though fallback works as expected, but custom shader is not Pull Request resolved: https://github.com/pytorch/pytorch/pull/146703 Approved by: https://github.com/jansel, https://github.com/dcci --- c10/metal/reduction_utils.h | 12 ++++++++++++ test/inductor/test_mps_basic.py | 1 + torch/_inductor/codegen/mps.py | 12 +++++++++++- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index bfc0af6f526..1dd7b78b1a7 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -29,6 +29,18 @@ opmath_t threadgroup_prod(threadgroup T* data, unsigned size) { return rc; } +template +float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) { + float m = data[0]; + float m2 = 0; + for (unsigned idx = 1; idx < size; ++idx) { + float delta = data[idx] - m; + m += delta / (idx + 1); + m2 += delta * (data[idx] - m); + } + return float2(m, m2); +} + template T threadgroup_max(threadgroup T* data, unsigned size) { // TODO: This should be moved to the callee diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 4257c9687bd..420a3f97c43 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -164,6 +164,7 @@ for test_name in [ "test_inf", "test_isinf", "test_isinf2", + "test_layer_norm", "test_lgamma", "test_linear_float64", "test_log_fp64", diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 241f8f94ca4..86cbb6f5361 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -12,7 +12,7 @@ from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_ from torch.utils._sympy.value_ranges import ValueRanges from ..utils import get_bounds_index_expr, get_kernel_metadata -from ..virtualized import ops, V +from ..virtualized import ops, OpsWrapper, V from .common import ( CSEVariable, DeferredLine, @@ -463,6 +463,16 @@ class MetalKernel(SIMDKernel): f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})", dtype=dtype, ) + if reduction_type == "welford_reduce": + acc_buf = self._new_accvar(src_dtype, reduction_dim.numel) + self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};") + wf_res = self.cse.generate( + self.body, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})", + ) + return OpsWrapper._unwrap( + (f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel) + ) raise NotImplementedError(reduction_type) def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: