Reference implementations for softmax, log_softmax, logsumexp (#79423)

This PR adds references for:

- `torch.softmax`
- `torch.log_softmax`
- `torch.logsumexp`

Unfortunately, none of them currently passes `test_python_ref_executor`, even with the `"aten"` executor.
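
For a quick smoke test, the references can be compared against their eager ATen counterparts. This is a minimal sketch, assuming a build that includes this PR so that `torch._refs` exposes all three functions:

```python
import torch
import torch._refs as refs

a = torch.randn(3, 4)

# Each reference should match its eager counterpart elementwise.
torch.testing.assert_close(refs.softmax(a, dim=1), torch.softmax(a, dim=1))
torch.testing.assert_close(refs.log_softmax(a, dim=1), torch.log_softmax(a, dim=1))
torch.testing.assert_close(refs.logsumexp(a, dim=1), torch.logsumexp(a, dim=1))
```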
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79423
Approved by: https://github.com/mruberry
Author: Ivan Yashchuk
Date: 2022-06-14 19:43:51 +00:00
Committed by: PyTorch MergeBot
Commit: 4fc7832d72 (parent 8895862744)
2 changed files with 80 additions and 1 deletion


```diff
@@ -520,6 +520,41 @@ def log10(a):
     return prims.log10(a)
 
 
+@out_wrapper
+def log_softmax(
+    a: TensorLikeType,
+    dim: int,
+    *,
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    result_dtype = dtype or a.dtype
+    computation_dtype = utils.get_computation_dtype(a.dtype)
+    a_ = _maybe_convert_to_dtype(a, computation_dtype)
+    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]
+
+
+@out_wrapper
+def logsumexp(
+    a: TensorLikeType,
+    dim: DimsType,
+    keepdim: bool = False,
+) -> TensorLikeType:
+    dim = utils.canonicalize_dims(a.ndim, dim)
+    # ATen specifies int[1] type dims which expands integers to tuples of length 1
+    if not isinstance(dim, Iterable):
+        dim = (dim,)
+    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
+        # For float and complex dtypes, we shift input to exp by a constant to avoid overflow
+        a_max = amax(a, dim, keepdim=True)
+        a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
+        a_max_squeezed = prims.squeeze(a_max, dim) if not keepdim else a_max
+        result = log(sum(exp(a - a_max), dim, keepdim=keepdim)) + a_max_squeezed
+    else:
+        # This case covers boolean and integer dtypes and we use non-stabilized computation
+        result = log(sum(exp(a), dim, keepdim=keepdim))
+    return result
+
+
 @register_decomposition(torch.ops.aten.nan_to_num)
 @out_wrapper
 @elementwise_type_promotion_wrapper(
```
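
For context on the float/complex branch above: the shift by `a_max` is the standard log-sum-exp trick. For any finite shift m, logsumexp(a) = m + log(Σ exp(a − m)), so subtracting the per-dimension max keeps the arguments to `exp` at or below zero for finite inputs and avoids overflow; the `where` guard resets an infinite `a_max` (e.g. an all-`-inf` slice) to 0 so that `a - a_max` does not produce NaNs. A standalone illustration using plain `torch` (not the ref itself):

```python
import torch

a = torch.tensor([1000.0, 1000.5])  # exp(1000.) overflows in float32

naive = torch.log(torch.exp(a).sum())           # inf: exp has overflowed
m = a.max()
stable = m + torch.log(torch.exp(a - m).sum())  # tensor(1000.9741)
print(naive, stable, torch.logsumexp(a, 0))     # stable matches torch.logsumexp
```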
```diff
@@ -1350,7 +1385,7 @@ def amin(
 def amax(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Optional[DimsType] = None,
     keepdim: bool = False,
     *,
     out: Optional[Tensor] = None,
```
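
The `amax` signature change above is a tidy-up: `DimsType` is the refs' int-or-sequence-of-ints alias, which the new `logsumexp` relies on when it forwards its canonicalized tuple of dims to `amax`. The same contract can be seen in the eager op (a sketch using `torch.amax`, which accepts both spellings):

```python
import torch

t = torch.randn(4, 5)
print(torch.amax(t, 1).shape)       # torch.Size([4])  -- single dim
print(torch.amax(t, (0, 1)).shape)  # torch.Size([])   -- sequence of dims
```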
```diff
@@ -2017,6 +2052,24 @@ def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
     return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
 
 
+@out_wrapper
+def softmax(
+    a: TensorLikeType,
+    dim: int,
+    *,
+    dtype: Optional[torch.dtype] = None,
+) -> TensorLikeType:
+    result_dtype = dtype or a.dtype
+    computation_dtype = utils.get_computation_dtype(a.dtype)
+    a_ = _maybe_convert_to_dtype(a, computation_dtype)
+    assert isinstance(a_, TensorLike)  # to avoid MyPy error for amax
+    a_max = amax(a_, dim, keepdim=True)
+    a_exp = exp(a_ - a_max)
+    return _maybe_convert_to_dtype(
+        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
+    )  # type: ignore[return-value]
+
+
 @out_wrapper
 def hstack(tensors: TensorSequenceType) -> TensorLikeType:
     check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
```
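
The `a_max` subtraction in `softmax` above uses the same stabilization idea: softmax is shift-invariant, i.e. softmax(a) = softmax(a − c) for any constant c, so subtracting the per-dim max bounds the arguments to `exp` by zero. A standalone illustration:

```python
import torch

a = torch.tensor([100.0, 101.0, 102.0])  # float32: exp(100.) is already inf

naive = torch.exp(a) / torch.exp(a).sum()  # tensor([nan, nan, nan])
shifted = torch.exp(a - a.max())           # arguments to exp are now <= 0
stable = shifted / shifted.sum()
print(naive)
print(torch.allclose(stable, torch.softmax(a, dim=0)))  # True
```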


```diff
@@ -3345,6 +3345,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
         ((S, S), (1,), True),
         ((S, S), (1,), False),
         ((S, S), (-2,), False),
+        ((S, S), (0, 1), False),
     )
     samples = []
     # Test large inputs to check numerical stability
```
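
Each tuple in `sample_inputs_logsumexp` is `(shape, dims, keepdim)`; the new case exercises reduction over multiple dimensions at once. A quick sketch of the output shapes these cases produce (`S = 5` here is a stand-in for the test suite's small-size constant):

```python
import torch

S = 5  # stand-in for the test suite's small-size constant
cases = (
    ((S, S), (1,), True),
    ((S, S), (1,), False),
    ((S, S), (-2,), False),
    ((S, S), (0, 1), False),  # new in this PR: multi-dim reduction
)
for shape, dims, keepdim in cases:
    out = torch.logsumexp(torch.randn(shape), dims, keepdim=keepdim)
    print(shape, dims, keepdim, tuple(out.shape))
```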
```diff
@@ -17534,6 +17535,7 @@ op_db: List[OpInfo] = [
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
+           gradcheck_fast_mode=False,
            sample_inputs_func=sample_inputs_logsumexp),
     OpInfo('trace',
            dtypes=all_types_and_complex(),
```
```diff
@@ -20064,6 +20066,22 @@ python_ref_db = [
         "_refs.log2",
         torch_opinfo_name="log2",
     ),
+    PythonRefInfo(
+        "_refs.logsumexp",
+        torch_opinfo_name="logsumexp",
+        skips=(
+            # SyntaxError: cannot assign to False
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.log_softmax",
+        torch_opinfo_name="log_softmax",
+        skips=(
+            # RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.nan_to_num",
         torch_opinfo_name="nan_to_num",
```
```diff
@@ -20122,6 +20140,14 @@ python_ref_db = [
         "_refs.sinh",
         torch_opinfo_name="sinh",
     ),
+    PythonRefInfo(
+        "_refs.softmax",
+        torch_opinfo_name="softmax",
+        skips=(
+            # RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+        ),
+    ),
     ElementwiseUnaryPythonRefInfo(
         "_refs.sqrt",
         torch_opinfo_name="sqrt",
```