diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 2f5f4d8ce2e..80d311a16dc 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -520,6 +520,41 @@ def log10(a):
     return prims.log10(a)
 
 
+@out_wrapper
+def log_softmax(
+    a: TensorLikeType,
+    dim: int,
+    *,
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    result_dtype = dtype or a.dtype
+    computation_dtype = utils.get_computation_dtype(a.dtype)
+    a_ = _maybe_convert_to_dtype(a, computation_dtype)
+    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]
+
+
+@out_wrapper
+def logsumexp(
+    a: TensorLikeType,
+    dim: DimsType,
+    keepdim: bool = False,
+) -> TensorLikeType:
+    dim = utils.canonicalize_dims(a.ndim, dim)
+    # ATen specifies int[1] type for dims, which expands integers to tuples of length 1
+    if not isinstance(dim, Iterable):
+        dim = (dim,)
+    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
+        # For float and complex dtypes, we shift the input to exp by a constant to avoid overflow
+        a_max = amax(a, dim, keepdim=True)
+        a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
+        a_max_squeezed = prims.squeeze(a_max, dim) if not keepdim else a_max
+        result = log(sum(exp(a - a_max), dim, keepdim=keepdim)) + a_max_squeezed
+    else:
+        # This case covers boolean and integer dtypes; we use the non-stabilized computation
+        result = log(sum(exp(a), dim, keepdim=keepdim))
+    return result
+
+
 @register_decomposition(torch.ops.aten.nan_to_num)
 @out_wrapper
 @elementwise_type_promotion_wrapper(
@@ -1350,7 +1385,7 @@ def amin(
 
 def amax(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Optional[DimsType] = None,
     keepdim: bool = False,
     *,
     out: Optional[Tensor] = None,
@@ -2017,6 +2052,24 @@ def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
     return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
 
 
+@out_wrapper
+def softmax(
+    a: TensorLikeType,
+    dim: int,
+    *,
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    result_dtype = dtype or a.dtype
+    computation_dtype = utils.get_computation_dtype(a.dtype)
+    a_ = _maybe_convert_to_dtype(a, computation_dtype)
+    assert isinstance(a_, TensorLike)  # to avoid MyPy error for amax
+    a_max = amax(a_, dim, keepdim=True)
+    a_exp = exp(a_ - a_max)
+    return _maybe_convert_to_dtype(
+        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
+    )  # type: ignore[return-value]
+
+
 @out_wrapper
 def hstack(tensors: TensorSequenceType) -> TensorLikeType:
     check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f0b25619088..67690afe2d4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3345,6 +3345,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
         ((S, S), (1,), True),
         ((S, S), (1,), False),
         ((S, S), (-2,), False),
+        ((S, S), (0, 1), False),
     )
     samples = []
     # Test large inputs to check numerical stability
@@ -17534,6 +17535,7 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
+           gradcheck_fast_mode=False,
            sample_inputs_func=sample_inputs_logsumexp),
     OpInfo('trace',
            dtypes=all_types_and_complex(),
@@ -20064,6 +20066,22 @@ python_ref_db = [
         "_refs.log2",
         torch_opinfo_name="log2",
     ),
+    PythonRefInfo(
+        "_refs.logsumexp",
+        torch_opinfo_name="logsumexp",
+        skips=(
+            # SyntaxError: cannot assign to False
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.log_softmax",
+        torch_opinfo_name="log_softmax",
+        skips=(
+            # RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.nan_to_num",
         torch_opinfo_name="nan_to_num",
@@ -20122,6 +20140,14 @@ python_ref_db = [
         "_refs.sinh",
         torch_opinfo_name="sinh",
     ),
+    PythonRefInfo(
+        "_refs.softmax",
+        torch_opinfo_name="softmax",
+        skips=(
+            # RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.sqrt",
         torch_opinfo_name="sqrt",