From eee2aa14a645042a03ba963133b60c8aee1ff57a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 31 May 2022 07:26:03 -0700 Subject: [PATCH] Register std_mean ref as a decomposition Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/78468 Approved by: https://github.com/ngimel --- test/test_decomp.py | 12 ++++++------ test/test_meta.py | 2 -- third_party/ideep | 2 +- torch/_refs/__init__.py | 9 +++++++-- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/test/test_decomp.py b/test/test_decomp.py index 1417949032b..7b7bb6222bf 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -144,12 +144,12 @@ def _getDefaultRtolAndAtol(dtype0, dtype1): return rtol, atol -def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs): - assert orig.dtype == decomp.dtype, f"Operation: {op}" +def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): + assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" if orig.numel() == 0 or decomp.numel() == 0: assert orig.numel() == decomp.numel() return - assert orig.shape == decomp.shape, f"Operation: {op}" + assert orig.shape == decomp.shape, f"{i} Operation: {op}" tol_table = { (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5, (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5, @@ -163,7 +163,7 @@ def op_assert_ref(test_case, op, test_dtype, orig, decomp, ref, args, kwargs): if decomp_diff > orig_diff + atol: raise RuntimeError( f"Difference from float64 is larger with decomposition {op.__name__}" - f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n" + f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n" f"atol = {atol}\n" f"args = {args}\n" f"kwargs = {kwargs}" @@ -414,11 +414,11 @@ class TestDecomp(TestCase): real_out_double, _ = tree_flatten( func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) ) - for orig, decomp, ref in zip(real_out, decomp_out, real_out_double): + for i, orig, decomp, ref in zip(range(len(real_out)), real_out, decomp_out, real_out_double): if orig is None: assert decomp is None continue - op_assert_ref(self, func, test_dtype, orig, decomp, ref, args, kwargs) + op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs) else: for orig, decomp in zip(real_out, decomp_out): if orig is None: diff --git a/test/test_meta.py b/test/test_meta.py index c2b66f08aa5..93685f175f7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -583,7 +583,6 @@ meta_function_expected_failures = { torch.roll: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::roll torch.searchsorted: {bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::searchsorted.Tensor, aten::searchsorted.Tensor_out torch.symeig: {f32, f64}, - torch.std_mean: {bf16, f16, f32, f64}, # aten::std_mean.correction torch.take: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::take, aten::take.out torch.trace: {f32, f64, i16, i32, i64, i8, u8}, # aten::trace torch.vdot: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::vdot @@ -848,7 +847,6 @@ meta_dispatch_expected_failures = { aten.rrelu_with_noise.default: {bf16, f64, f32}, aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, aten.searchsorted.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32}, - aten.std_mean.correction: {bf16, f16, f64, f32}, aten.take.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.take.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32}, aten.tensordot.out: {i64, bf16, u8, f32, i8, f64, i16, i32}, diff --git a/third_party/ideep b/third_party/ideep index 8a114a51c11..02b17c5748c 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 8a114a51c116b55c4ceb689b98746786bd00c29b +Subproject commit 02b17c5748c9349dcc586c359af800c684d9b1ab diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index b85f7bc9e20..cd31693bb45 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1414,18 +1414,22 @@ def std( if dim == () or dim == []: dim = None + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + result = _reduction( a, partial(prims.var, correction=correction), dims=dim, keepdims=keepdim, - dtype=None, + dtype=opmath_dtype, out=None, has_identity=True, output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, ) result = sqrt(result) - return result + return _maybe_convert_to_dtype(result, dtype) # type: ignore[return-value,arg-type] def mean( @@ -1469,6 +1473,7 @@ def mean( return result +@register_decomposition(torch.ops.aten.std_mean.correction) def std_mean( a: TensorLikeType, dim: Union[Optional[int], Optional[List[int]]] = None,