mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Register std_mean ref as a decomposition
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468 Approved by: https://github.com/ngimel
This commit is contained in:
parent
523c9c2ac2
commit
eee2aa14a6
4 changed files with 14 additions and 11 deletions
|
|
@@ -144,12 +144,12 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
|
|||
return rtol, atol
|
||||
|
||||
|
||||
def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
|
||||
assert orig.dtype == decomp.dtype, f"Operation: {op}"
|
||||
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
|
||||
assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
|
||||
if orig.numel() == 0 or decomp.numel() == 0:
|
||||
assert orig.numel() == decomp.numel()
|
||||
return
|
||||
assert orig.shape == decomp.shape, f"Operation: {op}"
|
||||
assert orig.shape == decomp.shape, f"{i} Operation: {op}"
|
||||
tol_table = {
|
||||
(torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
|
||||
(torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
|
||||
|
|
@@ -163,7 +163,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
|
|||
if decomp_diff > orig_diff + atol:
|
||||
raise RuntimeError(
|
||||
f"Difference from float64 is larger with decomposition {op.__name__}"
|
||||
f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
|
||||
f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
|
||||
f"atol = {atol}\n"
|
||||
f"args = {args}\n"
|
||||
f"kwargs = {kwargs}"
|
||||
|
|
@@ -414,11 +414,11 @@ class TestDecomp(TestCase):
|
|||
real_out_double, _ = tree_flatten(
|
||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||
)
|
||||
for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
|
||||
for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double):
|
||||
if orig is None:
|
||||
assert decomp is None
|
||||
continue
|
||||
op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
|
||||
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
else:
|
||||
for orig, decomp in zip(real_out, decomp_out):
|
||||
if orig is None:
|
||||
|
|
|
|||
|
|
@@ -583,7 +583,6 @@ meta_function_expected_failures = {
|
|||
torch.roll: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::roll
|
||||
torch.searchsorted: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::searchsorted.Tensor, aten::searchsorted.Tensor_out
|
||||
torch.symeig: {f32, f64},
|
||||
torch.std_mean: {bf16, f16, f32, f64}, # aten::std_mean.correction
|
||||
torch.take: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::take, aten::take.out
|
||||
torch.trace: {f32, f64, i16, i32, i64, i8, u8}, # aten::trace
|
||||
torch.vdot: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::vdot
|
||||
|
|
@@ -848,7 +847,6 @@ meta_dispatch_expected_failures = {
|
|||
aten.rrelu_with_noise.default: {bf16, f64, f32},
|
||||
aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
|
||||
aten.searchsorted.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
|
||||
aten.std_mean.correction: {bf16, f16, f64, f32},
|
||||
aten.take.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
|
||||
aten.take.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
|
||||
aten.tensordot.out: {i64, bf16, u8, f32, i8, f64, i16, i32},
|
||||
|
|
|
|||
2
third_party/ideep
vendored
2
third_party/ideep
vendored
|
|
@@ -1 +1 @@
|
|||
Subproject commit 8a114a51c116b55c4ceb689b98746786bd00c29b
|
||||
Subproject commit 02b17c5748c9349dcc586c359af800c684d9b1ab
|
||||
|
|
@@ -1414,18 +1414,22 @@ def std(
|
|||
if dim == () or dim == []:
|
||||
dim = None
|
||||
|
||||
opmath_dtype, dtype = utils.reduction_dtypes(
|
||||
a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
|
||||
)
|
||||
|
||||
result = _reduction(
|
||||
a,
|
||||
partial(prims.var, correction=correction),
|
||||
dims=dim,
|
||||
keepdims=keepdim,
|
||||
dtype=None,
|
||||
dtype=opmath_dtype,
|
||||
out=None,
|
||||
has_identity=True,
|
||||
output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
|
||||
)
|
||||
result = sqrt(result)
|
||||
return result
|
||||
return _maybe_convert_to_dtype(result, dtype) # type: ignore[return-value,arg-type]
|
||||
|
||||
|
||||
def mean(
|
||||
|
|
@@ -1469,6 +1473,7 @@ def mean(
|
|||
return result
|
||||
|
||||
|
||||
@register_decomposition(torch.ops.aten.std_mean.correction)
|
||||
def std_mean(
|
||||
a: TensorLikeType,
|
||||
dim: Union[Optional[int], Optional[List[int]]] = None,
|
||||
|
|
|
|||
Loading…
Reference in a new issue