Register std_mean ref as a decomposition

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468

Approved by: https://github.com/ngimel
Authored by Edward Z. Yang on 2022-05-31 07:26:03 -07:00; committed by PyTorch MergeBot
parent 523c9c2ac2
commit eee2aa14a6
4 changed files with 14 additions and 11 deletions

test/test_decomp.py

@@ -144,12 +144,12 @@ def _getDefaultRtolAndAtol(dtype0, dtype1):
     return rtol, atol
 
 
-def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
-    assert orig.dtype == decomp.dtype, f"Operation: {op}"
+def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
+    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
     if orig.numel() == 0 or decomp.numel() == 0:
         assert orig.numel() == decomp.numel()
         return
-    assert orig.shape == decomp.shape, f"Operation: {op}"
+    assert orig.shape == decomp.shape, f"{i} Operation: {op}"
     tol_table = {
         (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
         (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
@@ -163,7 +163,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs):
         if decomp_diff > orig_diff + atol:
             raise RuntimeError(
                 f"Difference from float64 is larger with decomposition {op.__name__}"
-                f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
+                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                 f"atol = {atol}\n"
                 f"args = {args}\n"
                 f"kwargs = {kwargs}"
@@ -414,11 +414,11 @@ class TestDecomp(TestCase):
                 real_out_double, _ = tree_flatten(
                     func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                 )
-                for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
+                for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double):
                     if orig is None:
                         assert decomp is None
                         continue
-                    op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs)
+                    op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
             else:
                 for orig, decomp in zip(real_out, decomp_out):
                     if orig is None:
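The hunks above thread an output index i through op_assert_ref so that failures on multi-output ops (std_mean produces two tensors) say which output drifted from the float64 reference. A minimal sketch of that per-output comparison, using a hypothetical assert_output_close helper rather than the real test harness:

import torch

# Hypothetical helper mirroring the per-output check in the test: it compares a result
# and its decomposition against a float64 reference and names the output index on failure.
def assert_output_close(i, orig, decomp, ref, atol=1e-7):
    orig_diff = (orig - ref).abs().max()
    decomp_diff = (decomp - ref).abs().max()
    if decomp_diff > orig_diff + atol:
        raise RuntimeError(f"output {i}: decomp max diff {decomp_diff} > original max diff {orig_diff} + atol {atol}")

a = torch.randn(8, 4)
real_out = torch.std_mean(a, dim=1)            # (std, mean) computed in float32
ref_out = torch.std_mean(a.double(), dim=1)    # float64 reference
decomp_out = real_out                          # stand-in for the decomposition's outputs
for i, (orig, decomp, ref) in enumerate(zip(real_out, decomp_out, ref_out)):
    assert_output_close(i, orig, decomp, ref)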

test/test_meta.py

@@ -583,7 +583,6 @@ meta_function_expected_failures = {
     torch.roll: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::roll
     torch.searchsorted: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::searchsorted.Tensor, aten::searchsorted.Tensor_out
     torch.symeig: {f32, f64},
-    torch.std_mean: {bf16, f16, f32, f64}, # aten::std_mean.correction
     torch.take: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::take, aten::take.out
     torch.trace: {f32, f64, i16, i32, i64, i8, u8}, # aten::trace
     torch.vdot: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::vdot
@@ -848,7 +847,6 @@ meta_dispatch_expected_failures = {
     aten.rrelu_with_noise.default: {bf16, f64, f32},
     aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
     aten.searchsorted.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
-    aten.std_mean.correction: {bf16, f16, f64, f32},
     aten.take.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.take.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.tensordot.out: {i64, bf16, u8, f32, i8, f64, i16, i32},
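Removing std_mean from both expected-failure tables reflects that, with the ref registered as a decomposition, aten.std_mean.correction can be expanded into primitives that already support meta tensors instead of needing its own meta kernel. A quick sanity check along those lines (assumed behavior after this commit, not part of the diff):

import torch

# std_mean on a meta tensor should propagate shapes and dtypes through the decomposition
# rather than failing for lack of a dedicated meta implementation.
x = torch.empty(4, 5, device="meta", dtype=torch.float32)
std, mean = torch.std_mean(x, dim=1)
print(std.shape, mean.shape)   # expected: torch.Size([4]) torch.Size([4])
print(std.device, std.dtype)   # expected: meta torch.float32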

third_party/ideep (vendored)

@@ -1 +1 @@
-Subproject commit 8a114a51c116b55c4ceb689b98746786bd00c29b
+Subproject commit 02b17c5748c9349dcc586c359af800c684d9b1ab

torch/_refs/__init__.py

@@ -1414,18 +1414,22 @@ def std(
     if dim == () or dim == []:
         dim = None
 
+    opmath_dtype, dtype = utils.reduction_dtypes(
+        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
+    )
+
     result = _reduction(
         a,
         partial(prims.var, correction=correction),
         dims=dim,
         keepdims=keepdim,
-        dtype=None,
+        dtype=opmath_dtype,
         out=None,
         has_identity=True,
         output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
     )
     result = sqrt(result)
-    return result
+    return _maybe_convert_to_dtype(result, dtype)  # type: ignore[return-value,arg-type]
 
 
 def mean(
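The std hunk above makes the ref run its variance reduction in an opmath (computation) dtype chosen by utils.reduction_dtypes and convert back to the output dtype only after the sqrt, keeping low-precision inputs numerically close to the eager kernel. A rough sketch of the same compute-wide-then-downcast pattern, written outside torch._refs and with the float16-to-float32 pairing assumed rather than looked up:

import torch

# Sketch of the pattern: upcast to a wider computation dtype, reduce, take the sqrt,
# then convert back to the dtype the caller expects.
def std_sketch(a: torch.Tensor, dim, keepdim: bool = False) -> torch.Tensor:
    opmath_dtype, out_dtype = torch.float32, a.dtype   # assumed pairing for float16/bfloat16 inputs
    var = torch.var(a.to(opmath_dtype), dim=dim, keepdim=keepdim)
    return torch.sqrt(var).to(out_dtype)

x = torch.randn(16, 8).half()
print(std_sketch(x, dim=1).dtype)   # torch.float16, but the intermediate math ran in float32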
@@ -1469,6 +1473,7 @@ def mean(
     return result
 
 
+@register_decomposition(torch.ops.aten.std_mean.correction)
 def std_mean(
     a: TensorLikeType,
     dim: Union[Optional[int], Optional[List[int]]] = None,
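The @register_decomposition line is the point of the commit: the std_mean ref is now registered under aten.std_mean.correction, so tracing and meta execution can expand the op into existing primitives instead of requiring a dedicated kernel. A hypothetical, simplified picture of what such a decomposition computes (the real ref in torch/_refs handles dtype promotion, correction, and keepdim through the machinery shown above):

import torch

# Simplified stand-in for the std_mean decomposition: both outputs are derived from
# the same input, and the pair (std, mean) is returned as a tuple.
def std_mean_sketch(a: torch.Tensor, dim, keepdim: bool = False):
    mean = torch.mean(a, dim=dim, keepdim=keepdim)
    var = torch.var(a, dim=dim, keepdim=keepdim)   # unbiased by default, matching torch.std_mean
    return torch.sqrt(var), mean

x = torch.randn(4, 6)
s, m = std_mean_sketch(x, dim=1)
s_ref, m_ref = torch.std_mean(x, dim=1)
print(torch.allclose(s, s_ref), torch.allclose(m, m_ref))   # expected: True True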