mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Reference implementations for softmax, log_softmax, logsumexp (#79423)
This PR adds references for: - `torch.softmax` - `torch.log_softmax` - `torch.logsumexp` Unfortunately, none of them currently pass `test_python_ref_executor` even with `"aten"` executor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79423 Approved by: https://github.com/mruberry
This commit is contained in:
parent
8895862744
commit
4fc7832d72
2 changed files with 80 additions and 1 deletions
|
|
@ -520,6 +520,41 @@ def log10(a):
|
|||
return prims.log10(a)
|
||||
|
||||
|
||||
@out_wrapper
def log_softmax(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.log_softmax``.

    Computes in a higher-precision dtype for numerical stability, then
    casts the result to ``dtype`` (or back to the input dtype when
    ``dtype`` is ``None``).
    """
    # An explicit dtype override wins; otherwise keep the input's dtype.
    out_dtype = dtype or a.dtype
    # Upcast (e.g. half -> float) so the exp/log inside logsumexp is stable.
    comp_dtype = utils.get_computation_dtype(a.dtype)
    upcast = _maybe_convert_to_dtype(a, comp_dtype)
    # log_softmax(a) == a - logsumexp(a) along `dim`.
    shifted = upcast - logsumexp(upcast, dim, keepdim=True)
    return _maybe_convert_to_dtype(shifted, out_dtype)  # type: ignore[return-value]
|
||||
|
||||
|
||||
@out_wrapper
def logsumexp(
    a: TensorLikeType,
    dim: DimsType,
    keepdim: bool = False,
) -> TensorLikeType:
    """Reference implementation of ``torch.logsumexp``.

    Reduces over ``dim``; for float/complex inputs the computation is
    stabilized by shifting with the per-slice maximum before ``exp``.
    """
    dims = utils.canonicalize_dims(a.ndim, dim)
    # ATen specifies int[1] type dims, so a bare int becomes a 1-tuple.
    if not isinstance(dims, Iterable):
        dims = (dims,)
    if not (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        # Boolean and integer dtypes: use the direct, non-stabilized formula.
        return log(sum(exp(a), dims, keepdim=keepdim))
    # Float/complex: shift the input to exp by a constant to avoid overflow.
    shift = amax(a, dims, keepdim=True)
    # An infinite max would make (a - shift) produce nan; drop the shift there.
    shift = where(abs(shift) == float("inf"), 0.0, shift)
    shift_out = shift if keepdim else prims.squeeze(shift, dims)
    return log(sum(exp(a - shift), dims, keepdim=keepdim)) + shift_out
|
||||
|
||||
|
||||
@register_decomposition(torch.ops.aten.nan_to_num)
|
||||
@out_wrapper
|
||||
@elementwise_type_promotion_wrapper(
|
||||
|
|
@ -1350,7 +1385,7 @@ def amin(
|
|||
|
||||
def amax(
|
||||
a: TensorLikeType,
|
||||
dim: Union[Optional[int], Optional[List[int]]] = None,
|
||||
dim: Optional[DimsType] = None,
|
||||
keepdim: bool = False,
|
||||
*,
|
||||
out: Optional[Tensor] = None,
|
||||
|
|
@ -2017,6 +2052,24 @@ def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
|
|||
return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
|
||||
|
||||
|
||||
@out_wrapper
def softmax(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of ``torch.softmax``.

    Computes in a higher-precision dtype for numerical stability, then
    casts the result to ``dtype`` (or back to the input dtype when
    ``dtype`` is ``None``).
    """
    # An explicit dtype override wins; otherwise keep the input's dtype.
    out_dtype = dtype or a.dtype
    comp_dtype = utils.get_computation_dtype(a.dtype)
    upcast = _maybe_convert_to_dtype(a, comp_dtype)
    assert isinstance(upcast, TensorLike)  # narrow for MyPy before amax
    # Subtract the per-slice max so exp never overflows.
    peak = amax(upcast, dim, keepdim=True)
    numer = exp(upcast - peak)
    denom = sum(numer, dim, keepdim=True)
    return _maybe_convert_to_dtype(
        true_divide(numer, denom), out_dtype
    )  # type: ignore[return-value]
|
||||
|
||||
|
||||
@out_wrapper
|
||||
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
|
||||
check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
|
||||
|
|
|
|||
|
|
@ -3345,6 +3345,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
|
|||
((S, S), (1,), True),
|
||||
((S, S), (1,), False),
|
||||
((S, S), (-2,), False),
|
||||
((S, S), (0, 1), False),
|
||||
)
|
||||
samples = []
|
||||
# Test large inputs to check numerical stability
|
||||
|
|
@ -17534,6 +17535,7 @@ op_db: List[OpInfo] = [
|
|||
assert_autodiffed=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
gradcheck_fast_mode=False,
|
||||
sample_inputs_func=sample_inputs_logsumexp),
|
||||
OpInfo('trace',
|
||||
dtypes=all_types_and_complex(),
|
||||
|
|
@ -20064,6 +20066,22 @@ python_ref_db = [
|
|||
"_refs.log2",
|
||||
torch_opinfo_name="log2",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.logsumexp",
|
||||
torch_opinfo_name="logsumexp",
|
||||
skips=(
|
||||
# SyntaxError: cannot assign to False
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.log_softmax",
|
||||
torch_opinfo_name="log_softmax",
|
||||
skips=(
|
||||
# RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
|
||||
),
|
||||
),
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.nan_to_num",
|
||||
torch_opinfo_name="nan_to_num",
|
||||
|
|
@ -20122,6 +20140,14 @@ python_ref_db = [
|
|||
"_refs.sinh",
|
||||
torch_opinfo_name="sinh",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.softmax",
|
||||
torch_opinfo_name="softmax",
|
||||
skips=(
|
||||
# RuntimeError: Tracing expected 3 arguments but got 2 concrete arguments
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
|
||||
),
|
||||
),
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.sqrt",
|
||||
torch_opinfo_name="sqrt",
|
||||
|
|
|
|||
Loading…
Reference in a new issue