mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPSInductor] Implement Welford reduction (#146703)
Still a work in progress: the fallback works as expected, but the custom shader does not yet. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146703 Approved by: https://github.com/jansel, https://github.com/dcci
This commit is contained in:
parent
69feef5a94
commit
2328dcccb9
3 changed files with 24 additions and 1 deletions
|
|
@ -29,6 +29,18 @@ opmath_t<T> threadgroup_prod(threadgroup T* data, unsigned size) {
|
|||
return rc;
|
||||
}
|
||||
|
||||
// Single-pass (Welford-style) reduction over `size` entries of `data`.
//
// Walks the buffer sequentially, maintaining a running mean and the running
// sum of squared deviations from that mean.
//
// Returns float2(mean, m2), where `m2` is the accumulated sum of squared
// deviations (not divided by the element count).
//
// Note: `data[0]` is read unconditionally, so callers must pass `size >= 1`.
template <typename T>
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
  // Seed the running mean with the first sample; no deviation accumulated yet.
  float mean = data[0];
  float sq_dev_sum = 0;
  unsigned pos = 1;
  while (pos < size) {
    // Distance of the current sample from the mean *before* the update.
    float diff = data[pos] - mean;
    // `pos + 1` is the number of samples consumed once this one is included.
    mean += diff / (pos + 1);
    // Multiply by the distance from the *updated* mean.
    sq_dev_sum += diff * (data[pos] - mean);
    ++pos;
  }
  return float2(mean, sq_dev_sum);
}
|
||||
|
||||
template <typename T>
|
||||
T threadgroup_max(threadgroup T* data, unsigned size) {
|
||||
// TODO: This should be moved to the callee
|
||||
|
|
|
|||
|
|
@ -164,6 +164,7 @@ for test_name in [
|
|||
"test_inf",
|
||||
"test_isinf",
|
||||
"test_isinf2",
|
||||
"test_layer_norm",
|
||||
"test_lgamma",
|
||||
"test_linear_float64",
|
||||
"test_log_fp64",
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_
|
|||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from ..utils import get_bounds_index_expr, get_kernel_metadata
|
||||
from ..virtualized import ops, V
|
||||
from ..virtualized import ops, OpsWrapper, V
|
||||
from .common import (
|
||||
CSEVariable,
|
||||
DeferredLine,
|
||||
|
|
@ -463,6 +463,16 @@ class MetalKernel(SIMDKernel):
|
|||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
|
||||
dtype=dtype,
|
||||
)
|
||||
if reduction_type == "welford_reduce":
|
||||
acc_buf = self._new_accvar(src_dtype, reduction_dim.numel)
|
||||
self.body.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
|
||||
wf_res = self.cse.generate(
|
||||
self.body,
|
||||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {reduction_dim.numel})",
|
||||
)
|
||||
return OpsWrapper._unwrap(
|
||||
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
|
||||
)
|
||||
raise NotImplementedError(reduction_type)
|
||||
|
||||
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue