[MPSInductor] Implement Welford reduction (#146703)

Still work in progress, though fallback works as expected, but custom shader is not

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146703
Approved by: https://github.com/jansel, https://github.com/dcci
This commit is contained in:
Nikita Shulga 2025-02-07 20:27:33 -08:00 committed by PyTorch MergeBot
parent 69feef5a94
commit 2328dcccb9
3 changed files with 24 additions and 1 deletions

View file

@ -29,6 +29,18 @@ opmath_t<T> threadgroup_prod(threadgroup T* data, unsigned size) {
return rc;
}
template <typename T>
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
float m = data[0];
float m2 = 0;
for (unsigned idx = 1; idx < size; ++idx) {
float delta = data[idx] - m;
m += delta / (idx + 1);
m2 += delta * (data[idx] - m);
}
return float2(m, m2);
}
template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee

View file

@ -164,6 +164,7 @@ for test_name in [
"test_inf",
"test_isinf",
"test_isinf2",
"test_layer_norm",
"test_lgamma",
"test_linear_float64",
"test_log_fp64",

View file

@ -12,7 +12,7 @@ from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_
from torch.utils._sympy.value_ranges import ValueRanges
from ..utils import get_bounds_index_expr, get_kernel_metadata
from ..virtualized import ops, V
from ..virtualized import ops, OpsWrapper, V
from .common import (
CSEVariable,
DeferredLine,
@ -463,6 +463,16 @@ class MetalKernel(SIMDKernel):
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
dtype=dtype,
)
if reduction_type == "welford_reduce":
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
wf_res = self.cse.generate(
self.body,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
)
return OpsWrapper._unwrap(
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
)
raise NotImplementedError(reduction_type)
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: