pytorch/torch/_refs/__init__.py
Ivan Yashchuk 4fc7832d72 Reference implementations for softmax, log_softmax, logsumexp (#79423)
This PR adds references for:

- `torch.softmax`
- `torch.log_softmax`
- `torch.logsumexp`

Unfortunately, none of them currently pass `test_python_ref_executor` even with `"aten"` executor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79423
Approved by: https://github.com/mruberry
2022-06-14 19:43:51 +00:00

2495 lines
74 KiB
Python

import torch
import torch._prims as prims
import torch._prims.utils as utils
from torch._prims.utils import (
check,
DimsType,
ShapeType,
StrideType,
TensorLike,
TensorLikeType,
DeviceLikeType,
TensorOrNumberLikeType,
DimsSequenceType,
TensorSequenceType,
Number,
NumberType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
REDUCTION_OUTPUT_TYPE_KIND,
is_weakly_lesser_type,
dtype_to_type,
)
from torch._prims.wrappers import (
elementwise_type_promotion_wrapper,
out_wrapper,
_maybe_convert_to_dtype,
_maybe_resize_out,
elementwise_unary_scalar_wrapper,
_safe_copy_out,
)
from collections.abc import Iterable
from functools import reduce, partial, wraps
from typing import Sequence, Optional, Union, Callable, List, Tuple
import operator
import warnings
import math
from enum import Enum
import collections
# Experimental module containing prototype Python references for existing
# PyTorch operations.
# Public names exported by `from torch._refs import *`.  Entries are grouped by
# operator family; commented-out names are placeholders kept from the original
# list for references that are not exported from here.
__all__ = [
    #
    # Elementwise Unary References
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "atan",
    "bitwise_not",
    # "cbrt",  # No corresponding torch operation
    "ceil",
    "cos",
    "cosh",
    "digamma",
    "erf",
    "erfinv",
    "erfc",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "frac",
    "isfinite",
    "isinf",
    "isnan",
    "i0",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "log_softmax",
    "nan_to_num",
    "neg",
    "positive",
    "reciprocal",
    "round",  # TODO: model kwargs
    "sigmoid",
    "sign",
    "signbit",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    #
    # Elementwise Binary References
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_left_shift",
    "bitwise_or",
    "bitwise_right_shift",
    "bitwise_xor",
    # "complex",
    # 'copysign', # where
    # 'div', # need to implement all rounding modes first
    "eq",
    "float_power",
    # 'floor_divide', # requires floor
    "fmax",
    "fmin",
    "fmod",
    # 'gcd',
    "ge",
    "gt",
    # 'heaviside',
    # 'hypot',
    "igamma",
    "igammac",
    "isclose",
    # 'lcm',
    # 'ldexp',
    "le",
    "logical_and",
    "logical_or",
    "logical_xor",
    "lt",
    # 'max', # implement with reductions
    "maximum",
    # 'min', # implement with reductions
    "minimum",
    "mul",
    "ne",
    "nextafter",
    # 'polar',  # abs, cos, sin
    "pow",
    # 'remainder',
    # 'rsub', # unblocked
    # # special.xlog1py
    # # special.zeta
    "sub",
    "true_divide",
    # 'xlogy', # where?, log, mul
    #
    # Elementwise Ternary References
    #
    "clamp",
    #
    # Conditional references
    #
    "masked_fill",
    "where",
    #
    # Data conversion and movement references
    #
    "clone",
    "copy_to",  # TODO: add OpInfo (or implement .to)
    "item",  # TODO: add OpInfo
    #
    # Reduction ops
    #
    "all",
    "amax",
    "amin",
    "any",
    "logsumexp",
    "mean",
    "std_mean",
    "var_mean",
    "sum",
    "prod",
    "var",
    #
    # Linear algebra ops
    #
    "addr",
    #
    # View & Shape Ops
    #
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "as_strided",
    "broadcast_shapes",
    "broadcast_tensors",
    "broadcast_to",
    "cat",
    "chunk",
    "column_stack",
    "dsplit",
    "dstack",
    "flatten",
    "flip",
    "fliplr",
    "flipud",
    "hsplit",
    "hstack",
    "narrow",
    "permute",
    "ravel",
    "reshape",
    "roll",
    "rot90",
    "stack",
    "swap_axes",  # alias for transpose
    "squeeze",
    "t",
    "tensor_split",
    "transpose",
    "unsqueeze",
    "view",
    "vsplit",
    "vstack",
    #
    # Tensor Creation
    #
    "empty",
    "empty_like",
    "empty_strided",
    "full",
    "full_like",
    "ones",
    "ones_like",
    "zeros",
    "zeros_like",
    #
    # Randomness References
    #
    "uniform",  # TODO: add OpInfo -- and testing for randomness?
    #
    # Test-related functions
    #
    "equal",  # TODO: add OpInfo
]
# Short module-level alias for torch.Tensor, used in annotations in this file.
Tensor = torch.Tensor
def _broadcast_shapes(*_shapes):
shapes = tuple(
(x,) if isinstance(x, int) else x
for x in filter(lambda x: x is not None, _shapes)
)
# Short-circuits on no input
if len(shapes) == 0:
return None
# Type checking
# TODO: make common validations available as utils
for shape in shapes:
assert isinstance(shape, Sequence)
# Computes common shape
common_shape = [
1,
] * reduce(max, (len(shape) for shape in shapes))
for shape in shapes:
for idx in range(-1, -1 - len(shape), -1):
if common_shape[idx] == 1:
if shape[idx] < 0:
raise ValueError(
"Attempting to broadcast a dimension with negative length!"
)
common_shape[idx] = shape[idx]
elif shape[idx] != 1:
if common_shape[idx] != shape[idx]:
raise RuntimeError(
"Attempting to broadcast a dimension of length ",
str(shape[idx]),
"!",
)
return common_shape
def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    """Broadcast every tensor argument to the arguments' common shape.

    Args:
        *args: tensors, Python numbers, or None, in any mix.
        preserve_cpu_scalar_tensors: when True, tensors for which
            utils.is_cpu_scalar_tensor holds are returned unchanged.

    Returns:
        A tuple with one entry per argument, in order.  None and numbers pass
        through untouched; a tensor that already has the common shape is
        returned as-is; any other tensor goes through prims.broadcast_in_dim.

    Raises:
        RuntimeError: if an argument is not None, a number, or a tensor.
    """
    # Computes common shape
    common_shape = _broadcast_shapes(
        *(t.shape if isinstance(t, TensorLike) else None for t in args)
    )

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x

            # `shape` is a list, so both sides are normalized to tuples; a
            # tuple never compares equal to a list.
            if tuple(x.shape) != tuple(shape):
                common_rank = len(shape) + 1
                start = common_rank - (len(x.shape) + 1)
                # x's dimensions map onto the trailing dimensions of `shape`
                dims = tuple(range(start, len(x.shape) + start))
                return prims.broadcast_in_dim(x, shape, dims)

            # Already the right shape: nothing to broadcast
            return x
        else:
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )

    return tuple(__maybe_broadcast(x, common_shape) for x in args)
# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
#
# Elementwise unary references
#
# Sentinel default for the `aten_op` argument of the reference factories below:
# it requests a lookup of the op on torch.ops.aten by the prim's __name__.
# A dedicated object is used because None already means "do not register".
infer_aten_op = object()
# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    disable_meta=False,
    extra_meta=None,
) -> Callable:
    """Build a decorator that turns a one-argument function into a unary reference.

    Args:
        type_promotion_kind: forwarded to elementwise_type_promotion_wrapper
            for the single promoted argument "a".
        aten_op: op to register the reference as a decomposition for.  The
            default sentinel looks up torch.ops.aten.<decorated __name__>;
            None skips registration.
        disable_meta: forwarded to register_decomposition.
        extra_meta: optional callable invoked with the input before the
            decorated function runs.

    Returns:
        A decorator; applying it to a function returns the wrapped reference.
    """

    def inner(prim: Callable):
        # aten_op may be rebound below when it has to be inferred
        nonlocal aten_op

        @wraps(prim)
        @out_wrapper
        @elementwise_unary_scalar_wrapper
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(a: TensorLikeType) -> TensorLikeType:
            if not isinstance(a, TensorLike):
                raise RuntimeError(
                    "Expected a tensor input for an elementwise unary operation!"
                )

            if extra_meta is not None:
                extra_meta(a)

            return prim(a)

        # Resolve the sentinel to the ATen op that shares the decorated name
        if aten_op is infer_aten_op:
            aten_op = getattr(torch.ops.aten, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op, disable_meta=disable_meta)(_ref)

        return _ref

    return inner
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
def abs(a):
    """Unary reference that delegates to prims.abs (COMPLEX_TO_FLOAT promotion)."""
    return prims.abs(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acos(a):
    """Unary reference that delegates to prims.acos (INT_TO_FLOAT promotion)."""
    return prims.acos(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acosh(a):
    """Unary reference that delegates to prims.acosh (INT_TO_FLOAT promotion)."""
    return prims.acosh(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asin(a):
    """Unary reference that delegates to prims.asin (INT_TO_FLOAT promotion)."""
    return prims.asin(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atan(a):
    """Unary reference that delegates to prims.atan (INT_TO_FLOAT promotion)."""
    return prims.atan(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def bitwise_not(a):
    """Unary reference that delegates to prims.bitwise_not (DEFAULT promotion)."""
    return prims.bitwise_not(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def ceil(a):
    """Unary reference that delegates to prims.ceil (DEFAULT promotion)."""
    return prims.ceil(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cos(a):
    """Unary reference that delegates to prims.cos (INT_TO_FLOAT promotion)."""
    return prims.cos(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cosh(a):
    """Unary reference that delegates to prims.cosh (INT_TO_FLOAT promotion)."""
    return prims.cosh(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def digamma(a):
    """Unary reference that delegates to prims.digamma (INT_TO_FLOAT promotion)."""
    return prims.digamma(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erf(a):
    """Unary reference that delegates to prims.erf (INT_TO_FLOAT promotion)."""
    return prims.erf(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfinv(a):
    """Unary reference that delegates to prims.erf_inv (INT_TO_FLOAT promotion).

    Note the naming difference: the reference is `erfinv`, the prim is `erf_inv`.
    """
    return prims.erf_inv(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfc(a):
    """Unary reference that delegates to prims.erfc (INT_TO_FLOAT promotion)."""
    return prims.erfc(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp(a):
    """Unary reference that delegates to prims.exp (INT_TO_FLOAT promotion)."""
    return prims.exp(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def expm1(a):
    """Unary reference that delegates to prims.expm1 (INT_TO_FLOAT promotion)."""
    return prims.expm1(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp2(a):
    """Unary reference that delegates to prims.exp2 (INT_TO_FLOAT promotion)."""
    return prims.exp2(a)
# Fill has its own implementation because it has a value parameter
@out_wrapper
@elementwise_type_promotion_wrapper(
    # Must be a real one-element tuple; ("a,") is just the string "a,"
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    """Reference for fill: delegates to prims.fill(a, value).

    Args:
        a: the tensor argument; its dtype decides which values are accepted.
        value: a Python number that must be safely castable to a's dtype.

    Raises:
        AssertionError: if `a` is not tensor-like or `value` is not a Number.
        ValueError: if `value`'s Python type is not weakly lesser than the
            Python type that corresponds to `a.dtype`.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = "value argument of type {0} cannot be safely cast to type {1}!".format(
            type(value), python_type
        )
        raise ValueError(msg)

    return prims.fill(a, value)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def floor(a):
    """Unary reference that delegates to prims.floor (DEFAULT promotion)."""
    return prims.floor(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def frac(x: TensorLikeType) -> TensorLikeType:
    """Return x minus its truncated part, i.e. x - floor(|x|) * sign(x)."""
    magnitude = floor(abs(x))
    truncated = mul(magnitude, sign(x))
    return sub(x, truncated)
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    """Boolean mask of finite elements.

    Float and complex inputs are delegated to prims.isfinite; for every other
    dtype the result is an all-True boolean tensor shaped like `a`.
    """
    needs_check = utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)
    if not needs_check:
        return ones_like(a, dtype=torch.bool)

    return prims.isfinite(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isinf(a: TensorLikeType) -> TensorLikeType:
    """Boolean mask of infinite elements.

    Float and complex inputs are delegated to prims.is_infinite; for every
    other dtype the result is an all-False boolean tensor shaped like `a`.
    """
    # TODO Add complex tensor support to remove is_infinite prim
    # if utils.is_complex_dtype(a):
    #     return bitwise_or(_isinf(real(a), _isinf(imag(a))
    # else:
    #     return bitwise_not(bitwise_or(isnan(a), isfinite(a)))
    needs_check = utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)
    if not needs_check:
        return zeros_like(a, dtype=torch.bool)

    return prims.is_infinite(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isnan(a: TensorLikeType) -> TensorLikeType:
    """Boolean mask computed as `a != a` via prims.ne (ALWAYS_BOOL promotion)."""
    return prims.ne(a, a)
# TODO: if this is special maybe it should be defined there and imported here?
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0
)
def i0(a):
    """Unary reference that delegates to prims.bessel_i0 (INT_TO_FLOAT promotion).

    Registered explicitly against aten.special_i0 rather than by name inference.
    """
    return prims.bessel_i0(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def lgamma(a):
    """Unary reference that delegates to prims.lgamma (INT_TO_FLOAT promotion)."""
    return prims.lgamma(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    """Unary reference that delegates to prims.log (INT_TO_FLOAT promotion)."""
    return prims.log(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log1p(a):
    """Unary reference that delegates to prims.log1p (INT_TO_FLOAT promotion)."""
    return prims.log1p(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log2(a):
    """Unary reference that delegates to prims.log2 (INT_TO_FLOAT promotion)."""
    return prims.log2(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log10(a):
    """Unary reference that delegates to prims.log10 (INT_TO_FLOAT promotion)."""
    return prims.log10(a)
@out_wrapper
def log_softmax(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference for log_softmax, computed as `a - logsumexp(a, dim, keepdim=True)`.

    Args:
        a: input tensor.
        dim: dimension along which the log-softmax is taken.
        dtype: optional result dtype; when omitted the input dtype is used.

    Returns:
        The log-softmax of `a`, converted to the result dtype.
    """
    result_dtype = dtype if dtype is not None else a.dtype
    # The computation dtype follows the *requested* dtype so that asking for a
    # wider result (e.g. float16 input with dtype=float64) is not computed at
    # the input's lower precision and merely widened afterwards.
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]
@out_wrapper
def logsumexp(
    a: TensorLikeType,
    dim: DimsType,
    keepdim: bool = False,
) -> TensorLikeType:
    """Reference for logsumexp: log(sum(exp(a))) over `dim`.

    Float and complex inputs are shifted by their (finite) maximum before
    exponentiation; boolean and integer inputs use the direct formula.
    """
    dim = utils.canonicalize_dims(a.ndim, dim)
    # ATen specifies int[1] type dims which expands integers to tuples of length 1
    if not isinstance(dim, Iterable):
        dim = (dim,)

    is_inexact = utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)
    if not is_inexact:
        # This case covers boolean and integer dtypes and we use non-stabilized computation
        return log(sum(exp(a), dim, keepdim=keepdim))

    # For float and complex dtypes, we shift input to exp by a constant to avoid overflow
    a_max = amax(a, dim, keepdim=True)
    # An infinite maximum is replaced by 0 so it is not used as the shift
    a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
    shift = a_max if keepdim else prims.squeeze(a_max, dim)
    return log(sum(exp(a - a_max), dim, keepdim=keepdim)) + shift
@register_decomposition(torch.ops.aten.nan_to_num)
@out_wrapper
@elementwise_type_promotion_wrapper(
    # Must be a real one-element tuple; ("a,") is just the string "a,"
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def nan_to_num(
    a: TensorLikeType,
    *,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    """Replace NaN, +inf and -inf entries of `a` with finite numbers.

    Args:
        a: input tensor.
        nan: replacement for NaN entries; None is treated as 0.0.
        posinf: replacement for +inf; None selects prims.maximum_value(a.dtype).
        neginf: replacement for -inf; None selects prims.minimum_value(a.dtype).

    Returns:
        A tensor with the replacements applied.  Boolean inputs are cloned.
    """
    assert isinstance(a, TensorLike)

    if a.dtype == torch.bool:
        return clone(a)

    # The signature allows nan=None; map it to the default replacement so it
    # does not reach `where` as a missing operand.
    if nan is None:
        nan = 0.0

    if posinf is None:
        posinf = prims.maximum_value(a.dtype)

    if neginf is None:
        neginf = prims.minimum_value(a.dtype)

    result = where(isnan(a), nan, a)

    # Both infinity masks share the same isinf/signbit computations
    is_inf = isinf(a)
    is_neg = signbit(a)

    is_neginf = bitwise_and(is_inf, is_neg)
    result = where(is_neginf, neginf, result)

    is_posinf = bitwise_and(is_inf, bitwise_not(is_neg))
    result = where(is_posinf, posinf, result)
    return result
def _neg_meta(a: TensorLikeType):
    """Input check for `neg`: raises RuntimeError for boolean tensors."""
    if a.dtype is not torch.bool:
        return
    raise RuntimeError("neg is not supported on bool tensors.")
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
)
def neg(a):
    """Unary reference that delegates to prims.neg (DEFAULT promotion).

    `_neg_meta` runs first and rejects boolean inputs.
    """
    return prims.neg(a)
# positive does not use _make_elementwise_unary_reference because it does not support out
def positive(a: TensorLikeType) -> TensorLikeType:
    """Return `a` itself; boolean tensors are rejected with RuntimeError."""
    assert isinstance(a, TensorLike)
    if a.dtype is not torch.bool:
        return a
    raise RuntimeError("positive does not support bool tensors.")
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def reciprocal(a):
    """Unary reference that delegates to prims.reciprocal (INT_TO_FLOAT promotion)."""
    return prims.reciprocal(a)
# TODO: round takes additional kwargs
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # TODO: this does need a decomp, but kwarg handling is needed
)
def round(a):
    """Unary reference that delegates to prims.round (DEFAULT promotion).

    Not registered as a decomposition (aten_op=None).
    """
    return prims.round(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    """Compute 1 / (1 + exp(-a)) from the neg, exp, add and true_divide references."""
    denominator = add(1, exp(neg(a)))
    return true_divide(1, denominator)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sign(a):
    """Unary reference that delegates to prims.sign (DEFAULT promotion)."""
    return prims.sign(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def signbit(a):
    """Unary reference that delegates to prims.signbit (ALWAYS_BOOL promotion)."""
    return prims.signbit(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sin(a):
    """Unary reference that delegates to prims.sin (INT_TO_FLOAT promotion)."""
    return prims.sin(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinh(a):
    """Unary reference that delegates to prims.sinh (INT_TO_FLOAT promotion)."""
    return prims.sinh(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sqrt(a):
    """Unary reference that delegates to prims.sqrt (INT_TO_FLOAT promotion)."""
    return prims.sqrt(a)
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # CompositeImplicitAutograd,
)
def square(a: TensorLikeType) -> TensorLikeType:
    """Compute a * a with the `mul` reference (BOOL_TO_LONG promotion).

    Not registered as a decomposition (aten_op=None).
    """
    return mul(a, a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    """Unary reference that delegates to prims.tan (INT_TO_FLOAT promotion)."""
    return prims.tan(a)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    """Unary reference that delegates to prims.tanh (INT_TO_FLOAT promotion)."""
    return prims.tanh(a)
def _make_elementwise_binary_reference(
    prim: Callable,
    *,
    type_promotion_kind,
    aten_op=infer_aten_op,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    disable_meta=False,
) -> Callable:
    """Wrap a two-argument callable into an elementwise binary reference.

    The returned reference type-promotes "a" and "b", validates Python-scalar
    operands, broadcasts both operands and then calls `prim`.

    Args:
        prim: callable invoked with the broadcast operands.
        type_promotion_kind: forwarded to elementwise_type_promotion_wrapper.
        aten_op: op to register the reference as a decomposition for.  The
            default sentinel looks the op up on torch.ops.aten using the part
            of prim.__name__ before the first "."; None skips registration.
        has_out: wrap the reference with out_wrapper when True.
        supports_lhs_python_scalar: whether `a` may be a Python number.
        supports_rhs_python_scalar: whether `b` may be a Python number.
        disable_meta: forwarded to register_decomposition.
    """

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a", "b"),
        type_promotion_kind=type_promotion_kind,
    )
    def _ref(
        a: Union[Tensor, NumberType],
        b: Union[Tensor, NumberType],
    ) -> Tensor:
        lhs_is_number = isinstance(a, Number)
        rhs_is_number = isinstance(b, Number)

        if lhs_is_number and not supports_lhs_python_scalar:
            raise ValueError(
                "Received a lhs Python scalar to an elementwise binary operation that does not accept lhs scalars!"
            )

        if rhs_is_number and not supports_rhs_python_scalar:
            raise ValueError(
                "Received a rhs Python scalar to an elementwise binary operation that does not accept rhs scalars!"
            )

        # TODO: enable this for operations that support it, like add
        if lhs_is_number and rhs_is_number:
            raise ValueError(
                "Receive two Number inputs to an elementwise binary operation!"
            )

        a, b = _maybe_broadcast(a, b)
        return prim(a, b)

    reference = out_wrapper(_ref) if has_out else _ref

    if aten_op is infer_aten_op:
        aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
    if aten_op is not None:
        register_decomposition(aten_op, disable_meta=disable_meta)(reference)

    return reference
# Add has its own implementation because it has an alpha argument
@register_decomposition(torch.ops.aten.add)
@out_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add

    Computes a + b, or a + alpha * b when alpha is given.  Raises ValueError
    when both operands are Python numbers or when alpha cannot be safely cast
    to the Python type of the operands' dtype.
    """
    both_numbers = isinstance(a, Number) and isinstance(b, Number)
    if both_numbers:
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    a, b = _maybe_broadcast(a, b)

    if alpha is None:
        return prims.add(a, b)

    # alpha is validated against the dtype of whichever operand is a tensor
    dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
    python_type = utils.dtype_to_type(dtype)
    if not utils.is_weakly_lesser_type(type(alpha), python_type):
        raise ValueError(
            "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                type(alpha), python_type
            )
        )

    scaled_b = prims.mul(b, alpha)
    return prims.add(a, scaled_b)
# TODO: add docstring
# Binary reference over prims.atan2 with INT_TO_FLOAT promotion; Python scalars
# are rejected on both sides.  The ATen op is inferred from the prim's name.
atan2 = _make_elementwise_binary_reference(
    prims.atan2,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
# TODO: add docstring
# Binary reference over prims.bitwise_and with DEFAULT promotion.
bitwise_and = _make_elementwise_binary_reference(
    prims.bitwise_and,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
# Binary reference over prims.shift_left with DEFAULT promotion; the ATen op is
# given explicitly because the prim and ATen names differ.
bitwise_left_shift = _make_elementwise_binary_reference(
    prims.shift_left,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.bitwise_left_shift,  # prim/aten name mismatch
)
# TODO: add docstring
# Binary reference over prims.bitwise_or with DEFAULT promotion.
bitwise_or = _make_elementwise_binary_reference(
    prims.bitwise_or,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
# Binary reference over prims.shift_right_arithmetic with DEFAULT promotion; the
# ATen op is given explicitly because the prim and ATen names differ.
bitwise_right_shift = _make_elementwise_binary_reference(
    prims.shift_right_arithmetic,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.bitwise_right_shift,  # prim/aten name mismatch
)
# TODO: add docstring
# Binary reference over prims.bitwise_xor with DEFAULT promotion.
bitwise_xor = _make_elementwise_binary_reference(
    prims.bitwise_xor,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
# TODO: add docstring
# Comparison reference over prims.eq; the result is always bool and a Python
# scalar is not accepted as the left operand.
eq = _make_elementwise_binary_reference(
    prims.eq,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
# TODO: add docstring
# Float power has its own implementation because it has unique type promotion.
# NB: aten_op not registered because CompositeExplicitAutograd
@out_wrapper
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    """Raise `a` to the power `b`, computing in float64 or complex128.

    complex128 is used when the higher dtype of the operands is complex,
    float64 otherwise.  Raises ValueError when both operands are Python numbers.
    """
    both_numbers = isinstance(a, Number) and isinstance(b, Number)
    if both_numbers:
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    # Handles type promotion
    dtype = utils.get_higher_dtype(a, b)
    assert dtype is not None
    dtype = torch.complex128 if utils.is_complex_dtype(dtype) else torch.float64

    # Float power has the following contiguous cast behavior to be
    # consistent with its C++ impl
    def _cast(operand):
        if isinstance(operand, TensorLike) and operand.dtype != dtype:
            return prims.to_dtype(operand, dtype)
        return operand

    a = _cast(a)
    b = _cast(b)

    a, b = _maybe_broadcast(a, b)
    return prims.pow(a, b)
# TODO: add docstring
# Binary reference over prims.fmax with DEFAULT promotion, registered against
# aten.fmax explicitly.
fmax = _make_elementwise_binary_reference(
    prims.fmax,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.fmax,
)
# TODO: add docstring
# Binary reference over prims.fmin with DEFAULT promotion, registered against
# aten.fmin explicitly.
fmin = _make_elementwise_binary_reference(
    prims.fmin,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.fmin,
)
# TODO: add docstring
# Binary reference over prims.fmod with DEFAULT promotion, registered against
# aten.fmod explicitly.
fmod = _make_elementwise_binary_reference(
    prims.fmod,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.fmod,
)
# TODO: add docstring
# Comparison reference over prims.ge; the result is always bool and a Python
# scalar is not accepted as the left operand.
ge = _make_elementwise_binary_reference(
    prims.ge,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
# TODO: add docstring
# Comparison reference over prims.gt; the result is always bool and a Python
# scalar is not accepted as the left operand.
gt = _make_elementwise_binary_reference(
    prims.gt,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
# Binary reference over prims.igamma with INT_TO_FLOAT promotion; Python
# scalars are rejected on both sides.
igamma = _make_elementwise_binary_reference(
    prims.igamma,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
# Binary reference over prims.igammac with INT_TO_FLOAT promotion; Python
# scalars are rejected on both sides.
igammac = _make_elementwise_binary_reference(
    prims.igammac,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    """Elementwise test that |a - b| <= atol + rtol * |b|.

    Args:
        a, b: tensors of the same dtype.
        rtol: non-negative relative tolerance.
        atol: non-negative absolute tolerance.
        equal_nan: when True, NaNs at the same position in float or complex
            inputs count as close.

    Returns:
        A boolean tensor; exactly equal elements are always reported as close.
    """
    check(
        a.dtype == b.dtype,
        lambda: "torch.isclose: Attempting to compare tensors of different dtypes {0} and {1}!".format(
            a.dtype, b.dtype
        ),
        ValueError,
    )
    check(
        rtol >= 0,
        lambda: "torch.isclose: rtol must be greater than or equal to zero, but got {0}!".format(
            rtol
        ),
    )
    check(
        atol >= 0,
        lambda: "torch.isclose: atol must be greater than or equal to zero, but got {0}!".format(
            atol
        ),
    )

    is_inexact = utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)

    close = eq(a, b)
    if equal_nan and is_inexact:
        both_nan = logical_and(isnan(a), isnan(b))
        close = logical_or(close, both_nan)

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not is_inexact:
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    within_tolerance = logical_and(
        isfinite(actual_error), le(actual_error, allowed_error)
    )
    return logical_or(close, within_tolerance)
# TODO: add docstring
# Comparison reference over prims.le; the result is always bool and a Python
# scalar is not accepted as the left operand.
le = _make_elementwise_binary_reference(
    prims.le,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def _logical_and(a: TensorLikeType, b: TensorLikeType):
    """AND of two tensors after mapping non-boolean operands to `x != 0`."""
    lhs = a if utils.is_boolean_dtype(a.dtype) else ne(a, 0)
    rhs = b if utils.is_boolean_dtype(b.dtype) else ne(b, 0)
    return bitwise_and(lhs, rhs)
# Binary reference built on the _logical_and helper; the result is always bool
# and it is registered against aten.logical_and.
logical_and = _make_elementwise_binary_reference(
    _logical_and,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=torch.ops.aten.logical_and,
)
def _logical_or(a: TensorLikeType, b: TensorLikeType):
    """OR of two tensors after mapping non-boolean operands to `x != 0`."""
    lhs = a if utils.is_boolean_dtype(a.dtype) else ne(a, 0)
    rhs = b if utils.is_boolean_dtype(b.dtype) else ne(b, 0)
    return bitwise_or(lhs, rhs)
# Binary reference built on the _logical_or helper; the result is always bool
# and it is registered against aten.logical_or.
logical_or = _make_elementwise_binary_reference(
    _logical_or,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=torch.ops.aten.logical_or,
)
def _logical_xor(a: TensorLikeType, b: TensorLikeType):
    """XOR of two tensors after mapping non-boolean operands to `x != 0`."""
    lhs = a if utils.is_boolean_dtype(a.dtype) else ne(a, 0)
    rhs = b if utils.is_boolean_dtype(b.dtype) else ne(b, 0)
    return bitwise_xor(lhs, rhs)
# TODO: skip unnecessary conversion of long to float
# Binary reference built on the _logical_xor helper; the result is always bool
# and it is registered against aten.logical_xor.
logical_xor = _make_elementwise_binary_reference(
    _logical_xor,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=torch.ops.aten.logical_xor,
)
# TODO: add docstring
# Comparison reference over prims.lt; the result is always bool and a Python
# scalar is not accepted as the left operand.
lt = _make_elementwise_binary_reference(
    prims.lt,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
# TODO: add docstring
# Binary reference over prims.maximum with DEFAULT promotion.
maximum = _make_elementwise_binary_reference(
    prims.maximum,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
# Binary reference over prims.minimum with DEFAULT promotion.
minimum = _make_elementwise_binary_reference(
    prims.minimum,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
# Binary reference over prims.mul with DEFAULT promotion.
mul = _make_elementwise_binary_reference(
    prims.mul,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: add docstring
# Comparison reference over prims.ne; the result is always bool and a Python
# scalar is not accepted as the left operand.
ne = _make_elementwise_binary_reference(
    prims.ne,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
# TODO: add docstring
# Binary reference over prims.nextafter with NO_OPMATH promotion; Python
# scalars are rejected on both sides.
nextafter = _make_elementwise_binary_reference(
    prims.nextafter,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
# TODO: add docstring
# Binary reference over prims.pow with BOOL_TO_LONG promotion.
pow = _make_elementwise_binary_reference(
    prims.pow,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
)
# TODO: add docstring
# TODO: consider refactoring this with add impl
# sub has its own implementation because it has an alpha argument
@register_decomposition(torch.ops.aten.sub)
@out_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.sub

    Computes a - b, or a - alpha * b when alpha is given.  Raises ValueError
    when both operands are Python numbers or when alpha cannot be safely cast
    to the Python type of the operands' dtype.
    """
    both_numbers = isinstance(a, Number) and isinstance(b, Number)
    if both_numbers:
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    a, b = _maybe_broadcast(a, b)

    if alpha is None:
        return prims.sub(a, b)

    # alpha is validated against the dtype of whichever operand is a tensor
    dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
    python_type = utils.dtype_to_type(dtype)
    if not utils.is_weakly_lesser_type(type(alpha), python_type):
        raise ValueError(
            "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                type(alpha), python_type
            )
        )

    scaled_b = prims.mul(b, alpha)
    return prims.sub(a, scaled_b)
# TODO: add docstring
# Binary reference over prims.div with INT_TO_FLOAT promotion; not registered
# as a decomposition (aten_op=None).
true_divide = _make_elementwise_binary_reference(
    prims.div,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=None,  # CompositeImplicitAutograd
)
#
# Elementwise Ternary References
#
@out_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "min", "max"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    """Limit `a` from below by `min` and/or from above by `max`.

    At least one bound must be given, otherwise ValueError is raised.  With
    both bounds the result is minimum(maximum(a, min), max).
    """
    a, min, max = _maybe_broadcast(a, min, max)

    has_min = min is not None
    has_max = max is not None

    if not has_min and not has_max:
        raise ValueError("clamp called but both min and max are none!")
    if not has_max:
        return maximum(a, min)
    if not has_min:
        return minimum(a, max)
    return minimum(maximum(a, min), max)
#
# Conditional references
#
# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(torch.ops.aten.where)
@out_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def where(
    pred: Tensor,
    a: Optional[TensorOrNumberLikeType] = None,
    b: Optional[TensorOrNumberLikeType] = None,
):
    """Select between `a` and `b` according to the boolean tensor `pred`.

    All three operands are broadcast together and passed to prims.where.

    Raises:
        NotImplementedError: if `a` or `b` is None (single-argument form).
        AssertionError: if `pred` is not a boolean tensor.
    """
    if a is None or b is None:
        raise NotImplementedError

    # CPU scalar tensors are allowed to mix with operands on another device
    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
    assert pred.dtype is torch.bool

    pred, a, b = _maybe_broadcast(pred, a, b)
    return prims.where(pred, a, b)
#
# Data Movement References
#
# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
def clone(
    a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    """Delegate to prims.clone, forwarding `memory_format` unchanged."""
    return prims.clone(a, memory_format=memory_format)
def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
    """Delegate to prims.copy_to(a, b).

    When `allow_cross_device` is False and the tensors are on different
    devices, RuntimeError is raised instead.
    """
    if not (allow_cross_device or a.device == b.device):
        raise RuntimeError(
            "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
                b.device, a.device
            )
        )

    return prims.copy_to(a, b)
def item(a: TensorLikeType) -> NumberType:
    """Convert a one-element tensor to a Python number.

    Raises ValueError when the tensor does not hold exactly one element.
    """
    if a.numel() == 1:
        # NOTE: explicit conversion is necessary for bool!
        # See https://github.com/pytorch/pytorch/issues/78071
        to_python = utils.dtype_to_type(a.dtype)
        return to_python(prims.item(a))

    msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
    raise ValueError(msg)
#
# Reduction references
#
def _reduction(
    a: TensorLikeType,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
) -> TensorLikeType:  # it is usually SAME, but I want
    # ref writers to actually think about what to put here
    """Shared driver for the reduction references.

    Validates the arguments, converts `a` to the computation dtype, applies
    `prim(a, dims)`, optionally restores the reduced dimensions with size 1
    and finally either copies into `out` or converts to the result dtype.

    Args:
        a: input tensor (at most 64 dimensions).
        prim: reduction callable taking the converted tensor and the dims.
        has_identity: when False, reducing over a zero-size dimension is an error.
        accepts_dim_tuple: when False, `dims` must be None or a single int.
        dims: dimensions to reduce; normalized through utils.reduction_dims.
        keepdims: keep the reduced dimensions with length 1.
        dtype: requested dtype; must match `out.dtype` when both are given.
        out: optional destination tensor; resized if needed and returned.
        output_dtype_kind: forwarded to utils.reduction_dtypes.

    Raises:
        RuntimeError: for more than 64 dimensions, a dtype/out mismatch, or a
            zero-size reduction with an identity-less operation.
    """
    assert isinstance(a, TensorLike)
    if a.ndim > 64:
        raise RuntimeError(
            "Received a tensor with {0} dimensions, but only tensors with up to 64 dims are supported!".format(
                a.ndim
            )
        )

    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, int)
    if isinstance(dims, int):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        # py_all is the builtin all, saved at module level before being shadowed
        valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    computation_dtype, result_dtype = utils.reduction_dtypes(
        a, output_dtype_kind, dtype
    )
    a_converted = prims.convert_element_type(a, computation_dtype)
    result = prim(a_converted, dims)
    if keepdims:
        # Reinsert every reduced dimension with length 1
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
    if out is not None:
        assert result_dtype is not None
        if dtype is not None and result_dtype != out.dtype:
            raise RuntimeError(
                "Expected the dtype of reduction result and out to match"
            )
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]

    if result.dtype != result_dtype and result_dtype is not None:
        result = prims.convert_element_type(result, result_dtype)
    return result
# Saves Python all
# (the builtin must be captured here, before this module defines its own `all`
# reference below and shadows the name)
py_all = all
@out_wrapper
def all(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    """Test whether every element along `dim` is nonzero.

    Implemented by counting the True entries and comparing the count with the
    number of elements being reduced.  uint8 inputs give a uint8 result.
    """
    if isinstance(dim, int):
        dim = (dim,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]

    # Computes nelem
    nelem = 1
    if a.ndim != 0:
        for i in dims:
            nelem = nelem * a.shape[i]

    as_bool = _maybe_convert_to_dtype(a, torch.bool)
    true_count = sum(as_bool, dim=dim, keepdim=keepdim)  # type: ignore[arg-type]
    result = eq(true_count, nelem)

    # Preserves uint8 -- probably a legacy mask thing
    if a.dtype is not torch.uint8:
        return result
    return prims.convert_element_type(result, torch.uint8)
@out_wrapper
def any(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    """Test whether any element along `dim` is nonzero.

    Implemented by counting the True entries and comparing the count against
    False.  uint8 inputs give a uint8 result.
    """
    as_bool = _maybe_convert_to_dtype(a, torch.bool)
    true_count = sum(as_bool, dim=dim, keepdim=keepdim)  # type: ignore[arg-type]
    result = ne(true_count, False)

    # Preserves uint8 -- probably a legacy mask thing
    if a.dtype is not torch.uint8:
        return result
    return prims.convert_element_type(result, torch.uint8)
@register_decomposition(torch.ops.aten.sum)
def sum(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Sum reduction built on prims.sum through `_reduction`.

    When `dtype` is omitted, boolean and integer inputs accumulate in int64
    and every other input keeps its own dtype.
    """
    if dtype is None:
        is_integral = utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(
            a.dtype
        )
        dtype = torch.int64 if is_integral else a.dtype

    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    reduction_kwargs = dict(
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
    return _reduction(a, prims.sum, **reduction_kwargs)
@register_decomposition(torch.ops.aten.prod)
def prod(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.prod`.

    When ``dtype`` is not given, boolean and integer inputs are multiplied as
    int64 and every other input keeps its own dtype. An empty ``dim``
    (``()`` or ``[]``) means "reduce over all dimensions".
    """
    if dtype is None:
        integral_input = utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(
            a.dtype
        )
        dtype = torch.int64 if integral_input else a.dtype

    # reduces over all dimensions if dim=() is passed
    reduce_over = None if dim in ((), []) else dim

    return _reduction(
        a,
        prims.prod,
        dims=reduce_over,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
def amin(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.amin`.

    Dispatches to ``prims.amin`` through ``_reduction``; the reduction is
    declared as having no identity value. An empty ``dim`` (``()`` or ``[]``)
    means "reduce over all dimensions".
    """
    # reduces over all dimensions if dim=() is passed
    reduce_over = None if dim in ((), []) else dim

    return _reduction(
        a,
        prims.amin,
        dims=reduce_over,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
def amax(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.amax`.

    Dispatches to ``prims.amax`` through ``_reduction``; the reduction is
    declared as having no identity value. An empty ``dim`` (``()`` or ``[]``)
    means "reduce over all dimensions".
    """
    # reduces over all dimensions if dim=() is passed
    reduce_over = None if dim in ((), []) else dim

    return _reduction(
        a,
        prims.amax,
        dims=reduce_over,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
def _set_correction(
unbiased: Optional[bool] = None,
correction: Optional[int] = None,
):
if correction is not None and unbiased is not None:
raise RuntimeError("cannot specify both correction and unbiased arguments")
elif correction is None and unbiased is None:
correction = 1
elif correction is None and unbiased is not None:
correction = 0 if unbiased is False else 1
if not isinstance(correction, int):
raise ValueError("correction argument should be integer")
if correction < 0:
raise ValueError("correction argument should be non-negative")
return correction
@out_wrapper
def var(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.var`.

    ``unbiased`` and ``correction`` are merged by ``_set_correction`` and the
    reduction is carried out by ``prims.var``. An empty ``dim`` (``()`` or
    ``[]``) means "reduce over all dimensions".
    """
    correction = _set_correction(unbiased, correction)

    # reduces over all dimensions if dim=() is passed
    reduce_over = None if dim in ((), []) else dim

    var_prim = partial(prims.var, correction=correction)
    return _reduction(
        a,
        var_prim,
        dims=reduce_over,
        keepdims=keepdim,
        dtype=None,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
@out_wrapper
def std(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.std`.

    Computes the variance with ``prims.var`` in the op-math dtype reported by
    ``utils.reduction_dtypes``, takes its square root, and converts the result
    to the reduction's output dtype.
    """
    correction = _set_correction(unbiased, correction)

    # reduces over all dimensions if dim=() is passed
    reduce_over = None if dim in ((), []) else dim

    opmath_dtype, result_dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )

    variance = _reduction(
        a,
        partial(prims.var, correction=correction),
        dims=reduce_over,
        keepdims=keepdim,
        dtype=opmath_dtype,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
    deviation = sqrt(variance)
    return _maybe_convert_to_dtype(deviation, result_dtype)  # type: ignore[return-value,arg-type]
def mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out=None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.mean`.

    Sums ``a`` over ``dim`` and divides by the number of reduced elements.

    Args:
        a: the tensor to reduce.
        dim: dimension(s) to reduce; ``None``, ``()`` or ``[]`` reduce over all.
        keepdim: forwarded to the underlying sum reduction.
        dtype: result dtype; defaults to ``a.dtype``.
        out: optional tensor that receives the result; its dtype must equal
            the result dtype.

    Raises:
        RuntimeError: if ``out`` has a different dtype than the result, or if
            the result dtype is an integer or boolean dtype.
    """
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    if dtype is None:
        dtype = a.dtype
    # can't use out wrapper because of this argument
    if out is not None and out.dtype != dtype:
        raise RuntimeError("expected out dtype and dtype to match")
    # Validates the result dtype before doing any work. Boolean dtypes are
    # rejected alongside integer ones, since neither can hold a mean.
    if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
        raise RuntimeError("result type should be floating point or complex")

    result = _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=None,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
    )

    # Computes the number of elements that were reduced
    if isinstance(dim, int):
        dim = (dim,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
    nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)

    result = true_divide(result, nelem)
    # dtype is never None at this point (it was defaulted to a.dtype above)
    result = _maybe_convert_to_dtype(result, dtype)  # type: ignore[assignment]

    if out is not None:
        assert isinstance(out, TensorLike)
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
    return result
@register_decomposition(torch.ops.aten.std_mean.correction)
def std_mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    """Reference implementation of :func:`torch.std_mean`.

    Returns the pair ``(std(a, ...), mean(a, ...))`` computed over the same
    ``dim`` / ``keepdim``.
    """
    deviation = std(a, dim, unbiased, keepdim, correction=correction)
    average = mean(a, dim, keepdim)
    return deviation, average
def var_mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    """Reference implementation of :func:`torch.var_mean`.

    Returns the pair ``(var(a, ...), mean(a, ...))`` computed over the same
    ``dim`` / ``keepdim``.
    """
    variance = var(a, dim, unbiased, keepdim, correction=correction)
    average = mean(a, dim, keepdim)
    return variance, average
@register_decomposition(torch.ops.aten.addr)
@out_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "vec1", "vec2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addr(
    self: TensorLikeType,
    vec1: TensorLikeType,
    vec2: TensorLikeType,
    beta: NumberType = 1,
    alpha: NumberType = 1,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.addr`.

    Computes ``beta * self + alpha * outer(vec1, vec2)`` with ``self`` expanded
    to ``(vec1.shape[0], vec2.shape[0])``. For boolean ``self`` the
    multiplications and the addition are replaced by their logical
    counterparts.

    Args:
        self: the tensor added to the outer product.
        vec1: 1-D tensor, first operand of the outer product.
        vec2: 1-D tensor, second operand of the outer product.
        beta: multiplier for ``self``.
        alpha: multiplier for the outer product.

    The validations below go through ``check``, which raises when its
    condition is false (vec1 / vec2 not 1-D, or beta / alpha of a type that
    does not fit the dtype of ``self``).
    """
    check(
        vec1.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
    )
    check(
        vec2.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
    )
    self = self.expand(vec1.shape[0], vec2.shape[0])
    if utils.is_boolean_dtype(self.dtype):
        # Integers are accepted for booleans
        check(
            is_weakly_lesser_type(type(beta), int),
            lambda: f"expected bool/int beta but got {type(beta)}",
        )
        check(
            is_weakly_lesser_type(type(alpha), int),
            # Reports alpha's type (the message used to print beta's)
            lambda: f"expected bool/int alpha but got {type(alpha)}",
        )
        # A falsy alpha zeroes out the outer-product term
        outer_term = torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
        if not beta:
            return outer_term
        return torch.logical_or(self, outer_term)
    else:
        check(
            is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
        )
        check(
            is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
        )
        if beta == 0:
            # This means NaNs from self are dropped if beta is zero
            return alpha * torch.outer(vec1, vec2)
        else:
            return beta * self + alpha * torch.outer(vec1, vec2)
def atleast_1d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_1d`.

    Accepts either one sequence of tensors or tensors as separate positional
    arguments. Every tensor with fewer than one dimension is unsqueezed at
    dim 0. Returns a tuple when more than one tensor was processed, otherwise
    the single resulting tensor.
    """
    # NOTE: the collections.Sequence alias was removed in Python 3.10;
    # the ABC has to be taken from collections.abc
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
    return res if len(res) > 1 else res[0]
# Helper function with assert to avoid MyPy error
# of incompatible type passed to unsqueeze
def _unsqueeze_atleast(
    at_least_fn: Callable, dim: int, arg: TensorLikeType
) -> TensorLikeType:
    """Applies ``at_least_fn`` to ``arg`` and unsqueezes the outcome at ``dim``."""
    promoted = at_least_fn(arg)
    assert isinstance(promoted, TensorLike)
    return unsqueeze(promoted, dim)
def atleast_2d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_2d`.

    Accepts either one sequence of tensors or tensors as separate positional
    arguments. Every tensor with fewer than two dimensions is first made at
    least 1-D and then unsqueezed at dim 0. Returns a tuple when more than
    one tensor was processed, otherwise the single resulting tensor.
    """
    # NOTE: the collections.Sequence alias was removed in Python 3.10;
    # the ABC has to be taken from collections.abc
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
    res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
    return res if len(res) > 1 else res[0]
def atleast_3d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_3d`.

    Accepts either one sequence of tensors or tensors as separate positional
    arguments. Every tensor with fewer than three dimensions is first made at
    least 2-D and then unsqueezed at dim -1. Returns a tuple when more than
    one tensor was processed, otherwise the single resulting tensor.
    """
    # NOTE: the collections.Sequence alias was removed in Python 3.10;
    # the ABC has to be taken from collections.abc
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
    res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
    return res if len(res) > 1 else res[0]
def as_strided(
    a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int = 0
) -> TensorLikeType:
    """Reference implementation of :func:`torch.as_strided`; forwards to ``prims.as_strided``."""
    strided = prims.as_strided(a, size, stride, storage_offset)
    return strided
def broadcast_shapes(*shapes) -> ShapeType:
    """Reference implementation of :func:`torch.broadcast_shapes`.

    Wraps the result of ``_broadcast_shapes`` in a ``torch.Size``.
    """
    common_shape = _broadcast_shapes(*shapes)
    return torch.Size(common_shape)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
    """Reference implementation of :func:`torch.broadcast_tensors`.

    Returns the outputs of ``_maybe_broadcast`` as a list; CPU scalar tensors
    are not given special treatment.
    """
    broadcasted = _maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)
    return list(broadcasted)
def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.broadcast_to`.

    The dimensions of ``a`` are mapped onto the trailing dimensions of
    ``size`` and the broadcast is done by ``prims.broadcast_in_dim``.
    """
    rank = len(a.shape)
    offset = len(size) - rank
    target_dims = tuple(offset + i for i in range(rank))
    return prims.broadcast_in_dim(a, size, target_dims)
@register_decomposition(torch.ops.aten.cat)
@out_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("tensors",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Reference implementation of :func:`torch.cat`.

    Args:
        tensors: non-empty sequence of tensors on the same device.
        dim: dimension to concatenate along; canonicalized against the rank
            of the first tensor.

    Returns:
        The concatenation of all tensors except 1-D tensors without elements,
        which are skipped. If every tensor is skipped, an empty 1-D tensor
        with the dtype and device of the first tensor is returned.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    # Function-scope import: the names `any` and `all` are rebound at module
    # scope by the reference implementations, so the builtin has to be
    # reached explicitly
    import builtins

    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    dim = utils.canonicalize_dim(tensors[0].ndim, dim)
    utils.validate_idx(tensors[0].ndim, dim)

    # Filters tensors with one dimension of length zero
    filtered = tuple(x for x in tensors if not (x.ndim == 1 and x.numel() == 0))
    if len(filtered) == 0:
        t = tensors[0]

        # TODO: fix this to work with meta tensors
        # Best effort on purpose: inputs that cannot report requires_grad
        # fall back to False
        try:
            requires_grad = builtins.any(x.requires_grad for x in tensors)
        except Exception:
            requires_grad = False

        return empty((0,), dtype=t.dtype, device=t.device, requires_grad=requires_grad)

    return prims.cat(filtered, dim)
@out_wrapper
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.column_stack`.

    Tensors with fewer than two dimensions get trailing dimensions added until
    they are 2-D; everything is then concatenated along dim 1.
    """
    columns = []
    for x in tensors:
        if x.ndim <= 1:
            missing_dims = list(range(x.ndim, 2))
            x = prims.expand_dims(x, missing_dims)
        columns.append(x)
    return cat(tuple(columns), 1)
@out_wrapper
def dstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.dstack`.

    Makes every tensor at least 3-D and concatenates along dim 2. ``check``
    rejects an empty ``tensors``.
    """
    check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
    aligned_tensors = atleast_3d(*tensors)
    # atleast_3d hands back a bare tensor when given exactly one tensor;
    # cat needs a sequence, so the tensor is wrapped again
    if isinstance(aligned_tensors, TensorLike):
        aligned_tensors = (aligned_tensors,)
    return cat(aligned_tensors, 2)
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.chunk`.

    Args:
        a: the tensor to split.
        chunks: requested number of chunks; must be positive.
        dim: dimension along which to split.

    Returns:
        A tuple of slices of ``a`` along ``dim``. Each chunk has
        ``ceil(length / chunks)`` elements along ``dim``, except possibly a
        shorter last one, so fewer than ``chunks`` slices may be returned.
        A dimension of length zero yields ``chunks`` empty slices.

    Raises:
        ValueError: if ``chunks`` is not positive.
    """
    if chunks <= 0:
        msg = "Expected at least one chunk, but got {0}!".format(chunks)
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]

    # A zero-length dimension would give chunk_size == 0 and a division by
    # zero below; it is split into `chunks` empty pieces instead
    if length == 0:
        return tuple(narrow(a, dim, 0, 0) for _ in range(chunks))

    # Integer ceil-division; avoids going through floating point
    chunk_size = (length + chunks - 1) // chunks
    full_chunks, tail_chunk_size = divmod(length, chunk_size)

    result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]
    if tail_chunk_size != 0:
        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(result)
# Note: flatten, unlike prim.collapse and prim.collapse_view has an inclusive end_dim
# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless
# a 0D tensor is flattened, in which case it's returned in 1D)
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    """Reference implementation of :func:`torch.flatten`.

    Collapses dimensions ``start_dim`` through ``end_dim`` (inclusive) into
    one. Produces a view when ``prims._collapse_view_helper`` reports that one
    is possible and a copy otherwise.
    """
    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
    end_dim = utils.canonicalize_dim(a.ndim, end_dim)

    # Short-circuits on no-op
    if start_dim == end_dim and a.ndim != 0:
        return a

    # The prims take an exclusive end
    collapse_end = end_dim + 1

    # Tries to take a view, makes a copy if it can't
    # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
    view_shape, _ = prims._collapse_view_helper(a, start_dim, collapse_end)
    collapse_fn = prims.collapse if view_shape is None else prims.collapse_view
    return collapse_fn(a, start_dim, collapse_end)
@register_decomposition(torch.ops.aten.flip)
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.flip`.

    ``dims`` must be a tuple or a list (ValueError otherwise); after
    canonicalization it is checked for repeated entries and handed to
    ``prims.rev``.
    """
    if not isinstance(dims, (tuple, list)):
        raise ValueError("dims has to be a sequence of ints")
    canonical_dims = utils.canonicalize_dims(a.ndim, dims)
    utils.validate_no_repeating_dims(canonical_dims)
    return prims.rev(a, canonical_dims)
def fliplr(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.fliplr`: flips dim 1 of a 2+D tensor."""
    if not a.ndim >= 2:
        raise RuntimeError("Input must be >= 2-d.")

    flipped = flip(a, (1,))
    return flipped
def flipud(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.flipud`: flips dim 0 of a 1+D tensor."""
    if not a.ndim >= 1:
        raise RuntimeError("Input must be >= 1-d.")

    flipped = flip(a, (0,))
    return flipped
def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType:
    """Reference implementation of :func:`torch.narrow`.

    Slices ``a`` along ``dim`` from ``start`` up to ``start + length``.
    """
    axis = utils.canonicalize_dim(a.ndim, dim)
    stop = start + length
    return prims.slice_in_dim(a, start, stop, axis=axis)
# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
@register_decomposition(torch.ops.aten.permute, disable_meta=True)
def permute(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.permute`.

    Canonicalizes ``dims`` against the rank of ``a`` and forwards the
    permutation to ``prims.transpose``.
    """
    canonical_order = utils.canonicalize_dims(a.ndim, dims)
    return prims.transpose(a, canonical_order)
def _reshape_view_helper(
    a: TensorLikeType, shape: ShapeType, *, allow_copy: bool
) -> TensorLikeType:
    """Shared implementation behind the reshape and view references.

    Args:
        a: the tensor to reshape.
        shape: the requested shape; at most one entry may be -1, in which case
            that length is inferred from the number of elements of ``a``.
        allow_copy: decides what happens when dimensions that would have to be
            merged cannot be collapsed as a view: if true the function falls
            back to ``prims.reshape``, otherwise it raises.

    Returns:
        ``a`` rearranged into ``shape``.

    Raises:
        ValueError: if more than one length is -1, if a length has to be
            inferred for a tensor without elements, if the element counts of
            ``a`` and ``shape`` differ, or if a view is required
            (``allow_copy=False``) but the needed collapse is not possible.
    """
    # NOTE: Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred

    # Creates a valid shape
    for idx in range(len(shape)):
        if shape[idx] == -1:
            # Verifies there's only one dimension of length -1 in the shape
            if shape.count(-1) > 1:
                msg = "Can only infer the length of one dimension, but got shape {0}!".format(
                    str(shape)
                )
                raise ValueError(msg)

            # TODO: improve error message
            if a.numel() > 0:
                # Divides the element count by every explicitly given length;
                # what remains is the length of the -1 dimension
                length = reduce(
                    operator.floordiv, (x for x in shape if x != -1), a.numel()
                )
            else:
                msg = "Cannot reshape a tensor of zero elements into shape {0} because the unspecified length is ambiguous!".format(
                    str(shape)
                )
                raise ValueError(msg)

            # Copies shape into a list so the inferred length can be stored
            shape = list(shape)
            shape[idx] = length
            break

    # Short-circuits if shape is the same
    utils.validate_shape(shape)
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    # Verifies that a and the requested shape describe the same number of elements
    numel = reduce(operator.mul, shape) if len(shape) > 0 else 1
    if a.numel() != numel:
        msg = "Attempting to reshape a tensor with shape {0} and {1} elements to a shape {2} with {3} elements!".format(
            str(a.shape), a.numel(), str(shape), numel
        )
        raise ValueError(msg)

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        # Every requested length must be 1; one unsqueeze per requested dimension
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        # Every existing length must be 1; one squeeze per existing dimension
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.

    # idx is the position, in the partially reshaped tensor a_, of the
    # dimension that is being constructed next
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            last_dim = a_.ndim - 1
            # NOTE: using split_dim instead of unsqueeze may seem silly here,
            # but it's necessary to get the strides correct
            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view or a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape
                )
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_
# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
def reshape(a: TensorLikeType, shape: ShapeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.reshape`.

    Delegates to ``_reshape_view_helper`` with copying allowed.
    """
    reshaped = _reshape_view_helper(a, shape, allow_copy=True)
    return reshaped
@register_decomposition(torch.ops.aten.roll)
def roll(
    a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple()
) -> TensorLikeType:
    """Reference implementation of :func:`torch.roll`.

    A single (shift, dim) pair is implemented by cutting the tensor in two
    along ``dim`` and concatenating the pieces in swapped order. Several
    pairs are handled recursively, one dimension at a time. Without ``dims``
    the tensor is flattened, rolled, and viewed back to its original shape.
    """
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)

    # Avoid modulo by zero
    if a.numel() == 0:
        # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
        return clone(a)

    num_shifts = len(shifts)
    num_dims = len(dims)

    if num_shifts == 1 and num_dims == 1:
        # Exactly one dimension is rolled; the multi-dimension path below
        # ends up here for each of its dimensions
        axis = dims[0]
        extent = a.shape[axis]
        pivot = (extent - shifts[0]) % extent
        leading = torch.narrow(a, axis, pivot, extent - pivot)
        trailing = torch.narrow(a, axis, 0, pivot)
        return torch.cat((leading, trailing), axis)

    if num_shifts == 0:
        raise RuntimeError("`shifts` required")

    # Takes care of the case when dims is not specified (default)
    # By default, the tensor is flattened before shifting, after which the original shape is restored
    if num_dims == 0 and num_shifts == 1:
        return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)

    if num_shifts != num_dims:
        raise RuntimeError(
            f"shifts and dimensions must align. shifts: {num_shifts}, dims: {num_dims}"
        )
    assert num_dims > 1

    # Rolls the first dimension, then recurses on the remaining ones
    first_dim_rolled = torch.roll(a, shifts[0], dims[0])
    return torch.roll(first_dim_rolled, shifts[1:], dims[1:])
@register_decomposition(torch.ops.aten.rot90)
def rot90(
    a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
) -> TensorLikeType:
    """Reference implementation of :func:`torch.rot90`.

    Rotates ``a`` by ``k`` quarter turns in the plane given by the two
    distinct dimensions in ``dims``. ``k`` is reduced modulo 4; when no
    rotation is left a clone of ``a`` is returned.
    """
    dims_ = utils.canonicalize_dims(a.ndim, dims)
    # Required to silence MyPy errors
    assert isinstance(dims_, (tuple, list))
    dims = dims_

    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if a.ndim < 2:
        raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
    if dims[0] == dims[1]:
        raise RuntimeError(
            f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
        )

    first, second = dims[0], dims[1]
    # Rotation direction is from the second towards the first axis for k < 0
    quarter_turns = k % 4

    if quarter_turns == 1:
        return torch.transpose(torch.flip(a, (second,)), first, second)
    if quarter_turns == 2:
        return torch.flip(a, dims)
    if quarter_turns == 3:
        return torch.transpose(torch.flip(a, (first,)), first, second)
    return clone(a)
def _check_stack_inputs(tensors: TensorSequenceType) -> None:
    """Asserts that every tensor in ``tensors`` has the shape of the first one.

    Raises:
        AssertionError: naming the first entry whose shape differs from the
            shape of entry 0.
    """
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        # NOTE: the two message pieces are concatenated, so the first one
        # has to end with a space
        assert tensors[i].shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensors[i].shape} at entry {i}"
        )
@register_decomposition(torch.ops.aten.stack)
@out_wrapper
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Reference implementation of :func:`torch.stack`.

    When the new dimension is not the innermost one, the tensors are
    concatenated along it and the result is viewed with the extra dimension
    inserted; otherwise every tensor is unsqueezed first and the unsqueezed
    tensors are concatenated.
    """
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    first = tensors[0]
    wrapped_dim = utils.canonicalize_dim(first.ndim + 1, dim)

    if not wrapped_dim < first.ndim:
        # If dim == tensors[0].ndim, view cannot efficiently handle it
        return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)

    # Refs need sparse support to check other condition
    # (and not tensors[0].is_sparse)
    _check_stack_inputs(tensors)
    stacked_sizes = list(first.shape)
    stacked_sizes.insert(wrapped_dim, len(tensors))
    concatenated = torch.cat(tensors, wrapped_dim)
    return concatenated.view(stacked_sizes)
@out_wrapper
def softmax(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.softmax`.

    Args:
        a: the input tensor.
        dim: the dimension along which the softmax is computed.
        dtype: optional result dtype. When given, the input is converted
            before the computation, so the computation dtype is derived from
            it rather than from ``a.dtype``.

    Returns:
        ``exp(a - amax(a)) / sum(exp(a - amax(a)))`` along ``dim``, converted
        to the result dtype.
    """
    result_dtype = dtype or a.dtype
    # The computation dtype follows the requested result dtype: deriving it
    # from a.dtype would compute at the input's lower precision (or in an
    # integer dtype) even though a wider dtype was asked for
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    assert isinstance(a_, TensorLike)  # to avoid MyPy error for amax
    if a.numel() == 0:
        # amax is declared without an identity value, so it is skipped for
        # inputs without elements; there is nothing to stabilize anyway
        a_exp = exp(a_)
    else:
        # Subtracting the maximum keeps exp from overflowing
        a_max = amax(a_, dim, keepdim=True)
        a_exp = exp(a_ - a_max)
    return _maybe_convert_to_dtype(
        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
    )  # type: ignore[return-value]
@out_wrapper
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.hstack`.

    Makes every tensor at least 1-D, then concatenates along dim 0 when the
    first tensor is 1-D and along dim 1 otherwise. ``check`` rejects an empty
    ``tensors``.
    """
    check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
    aligned_tensors = atleast_1d(*tensors)
    # atleast_1d hands back a bare tensor when given exactly one tensor;
    # indexing it with [0] would then select its first slice rather than the
    # first tensor, so the tensor is wrapped again
    if isinstance(aligned_tensors, TensorLike):
        aligned_tensors = (aligned_tensors,)
    if aligned_tensors[0].ndim == 1:
        return cat(aligned_tensors, 0)
    return cat(aligned_tensors, 1)
@out_wrapper
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.vstack`.

    Makes every tensor at least 2-D and concatenates along dim 0. ``check``
    rejects an empty ``tensors``.
    """
    check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
    aligned_tensors = atleast_2d(*tensors)
    # atleast_2d hands back a bare tensor when given exactly one tensor;
    # cat needs a sequence, so the tensor is wrapped again
    if isinstance(aligned_tensors, TensorLike):
        aligned_tensors = (aligned_tensors,)
    return cat(aligned_tensors, 0)
# Note: although squeeze is documented as having the out= kwarg it doesn't
def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
    """Reference implementation of :func:`torch.squeeze`.

    Without ``dim`` every dimension of length 1 is removed. With ``dim`` only
    that dimension is removed, and only if its length is 1; otherwise a view
    of the unchanged tensor is returned.
    """
    if dim is None:
        unit_dims = tuple(i for i, extent in enumerate(a.shape) if extent == 1)
        return prims.squeeze(a, unit_dims)

    dim = utils.canonicalize_dim(a.ndim, dim)

    # Short-circuits if the tensor has no dimensions
    if len(a.shape) == 0:
        assert dim == 0
        return prims.view_of(a)

    # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
    if a.shape[dim] == 1:
        return prims.squeeze(a, (dim,))
    return prims.view_of(a)
# Note: does not work with TensorMetas because of data-dependent control-flow
def tensor_split(
    a: TensorLikeType,
    indices_or_sections: Union[Tensor, DimsType],
    dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.tensor_split`.

    Args:
        a: the tensor to split; must have at least one dimension.
        indices_or_sections: either a number of sections (an int or a
            0-D tensor) or the indices at which to split (a sequence of ints
            or a 1-D tensor). Tensors must live on the CPU and have long dtype.
        dim: the dimension along which to split.

    Returns:
        A tuple of slices of ``a`` along ``dim``. For ``n`` sections the first
        ``size % n`` slices are one element longer than the others. For
        indices, consecutive slices ``[previous index, index)`` are returned,
        followed by a last slice reaching to the end of the dimension.

    Raises:
        ValueError: for a 0-D ``a``, for an index tensor that is not a CPU
            long tensor or has more than one dimension, and for a
            non-positive number of sections.
    """
    _dim = utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor, it must be a CPU Long tensor
    if isinstance(indices_or_sections, TensorLike):
        if indices_or_sections.device != torch.device("cpu"):
            msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format(
                indices_or_sections.device
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            # Parenthesized so that both pieces make up the message; a bare
            # string on its own line would be a no-op statement
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                f" but received one with dtype {indices_or_sections.dtype}"
            )
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
    if isinstance(indices_or_sections, int) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections  # type: ignore[assignment]
            if isinstance(indices_or_sections, Number)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = "tensor_split: number of sections must be greater than 0, but was {0}".format(
                sections
            )
            raise ValueError(msg)

        splits = []
        dim_size = a.shape[_dim]
        # Integer floor division; avoids going through floating point
        min_split_size = dim_size // sections
        num_splits_one_extra = dim_size % sections
        start_idx = 0
        for split_idx in range(sections):
            # The leading num_splits_one_extra splits absorb the remainder
            split_size = (
                min_split_size + 1
                if (split_idx < num_splits_one_extra)
                else min_split_size
            )
            s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
            splits.append(s)
            start_idx = start_idx + split_size

        return tuple(splits)
    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
    else:
        indices = indices_or_sections
        if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                # Parenthesized for the same reason as the dtype message above
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    f"but received a tensor with {indices_or_sections.ndim} dimensions"
                )
                raise ValueError(msg)

            indices = indices_or_sections.tolist()

        splits = []
        start_idx = 0
        for x in indices:
            splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
            start_idx = x
        # The last split covers everything after the final index
        splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
        return tuple(splits)
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.hsplit`.

    Splits along dim 0 for 1-D tensors and along dim 1 otherwise. An integer
    ``indices_or_sections`` must be nonzero and divide the size of that
    dimension; anything else must be a list or a tuple (TypeError otherwise)
    and is forwarded to ``tensor_split`` as split indices.
    """
    check(
        a.ndim >= 1,
        lambda: (
            "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    dim = 1 if a.ndim != 1 else 0

    if not isinstance(indices_or_sections, int):
        check(
            isinstance(indices_or_sections, (list, tuple)),
            lambda: (
                "hsplit(): received an invalid combination of arguments. "
                "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
                f"but got type {type(indices_or_sections)}"
            ),
            exc_type=TypeError,
        )
        return tensor_split(a, indices_or_sections, dim)

    split_size = indices_or_sections
    evenly_divisible = split_size != 0 and a.shape[dim] % split_size == 0
    check(
        evenly_divisible,
        lambda: (
            "torch.hsplit attempted to split along dimension "
            + str(dim)
            + ", but the size of the dimension "
            + str(a.shape[dim])
            + " is not divisible by the split_size "
            + str(split_size)
            + "!"
        ),
    )
    return tensor_split(a, split_size, dim)
def vsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.vsplit`.

    Splits a tensor with at least two dimensions along dim 0. An integer
    ``indices_or_sections`` must be nonzero and divide the size of dim 0;
    anything else must be a list or a tuple (TypeError otherwise) and is
    forwarded to ``tensor_split`` as split indices.
    """
    check(
        a.ndim >= 2,
        lambda: (
            "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    if isinstance(indices_or_sections, int):
        split_size = indices_or_sections
        check(
            (split_size != 0 and a.shape[0] % split_size == 0),
            # No space between the dimension and the comma
            # ("dimension 0, but ..."), in line with the hsplit message
            lambda: (
                "torch.vsplit attempted to split along dimension 0"
                f", but the size of the dimension {a.shape[0]}"
                f" is not divisible by the split_size {split_size}!"
            ),
        )
        return tensor_split(a, split_size, 0)

    check(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "vsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
        exc_type=TypeError,
    )

    split_sizes = indices_or_sections
    return tensor_split(a, split_sizes, 0)
def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
    """Reference implementation of :func:`torch.dsplit`.

    Splits a tensor with at least three dimensions along dim 2. An integer
    ``sections`` must be nonzero and divide the size of dim 2
    (RuntimeError otherwise).
    """
    if not a.ndim >= 3:
        raise RuntimeError(
            f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
        )

    if isinstance(sections, int):
        uneven = sections == 0 or a.shape[2] % sections != 0
        if uneven:
            raise RuntimeError(
                "torch._refs.dsplit attempted to split along dimension 2, "
                + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
            )

    return tensor_split(a, sections, 2)
@register_decomposition(torch.ops.aten.t.default)
def t(a: TensorLikeType):
    """Reference implementation of :func:`torch.t`.

    Swaps dims 0 and 1 of a 2-D tensor; for tensors with fewer dimensions dim
    0 is transposed with itself. Tensors with more than two dimensions are
    rejected with a RuntimeError.
    """
    # TODO: Add sparse support
    # if a.is_sparse:
    #     sparse_dim = a.sparse_dim()
    #     dense_dim = a.dense_dim()
    #     if not (sparse_dim <= 2 and dense_dim == 0):
    #         raise RuntimeError(
    #             f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
    #             f"{dense_dim} dense dimensions"
    #         )
    if a.ndim > 2:
        raise RuntimeError(
            f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
        )
    other_dim = 1 if a.ndim >= 2 else 0
    return torch.transpose(a, 0, other_dim)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    """Reference implementation of :func:`torch.transpose`.

    Returns a view of ``a`` for tensors with at most one dimension or when
    ``dim0 == dim1``; otherwise builds the permutation that swaps the two
    dimensions and hands it to ``prims.transpose``.
    """
    first, second = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

    nothing_to_swap = a.ndim <= 1 or dim0 == dim1
    if nothing_to_swap:
        return prims.view_of(a)

    order = list(range(0, a.ndim))
    order[first], order[second] = second, first
    return prims.transpose(a, order)
# Aliases for transpose
# swap_axes is bound to the very same function object as transpose
swap_axes = transpose
@register_decomposition(torch.ops.aten.unsqueeze)
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
    """Reference implementation of :func:`torch.unsqueeze`.

    Inserts a dimension at position ``dim`` through ``prims.expand_dims``.
    """
    # Note that unsqueeze canonicalizes with rank + 1 because it allows
    # a new innermost dimension to be specified
    output_rank = a.ndim + 1
    position = utils.canonicalize_dim(output_rank, dim)
    return prims.expand_dims(a, (position,))
# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
def view(a: TensorLikeType, shape: ShapeType) -> TensorLikeType:
    """Reference implementation of :meth:`torch.Tensor.view`.

    Delegates to ``_reshape_view_helper`` with copying disallowed.
    """
    viewed = _reshape_view_helper(a, shape, allow_copy=False)
    return viewed
def ravel(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.ravel`: reshapes ``a`` to ``(-1,)``."""
    flat_shape = (-1,)
    return reshape(a, flat_shape)
@out_wrapper
def empty(
    *shape,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.empty`.

    The shape may be given as separate ints or as one sequence. The result
    is created by ``empty_strided`` with contiguous strides.
    """
    sizes = utils.extract_shape_from_varargs(shape)
    contiguous_strides = utils.make_contiguous_strides_for(sizes)
    return empty_strided(
        sizes,
        contiguous_strides,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
def empty_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.empty_like`.

    ``dtype`` and ``device`` default to those of ``a``. The strides are those
    of ``a`` when it has no elements and otherwise whatever
    ``utils.compute_elementwise_output_strides`` returns for it.
    """
    if dtype is None:
        dtype = a.dtype
    if device is None:
        device = a.device

    strides: Tuple[int, ...]
    if a.numel() != 0:
        strides = utils.compute_elementwise_output_strides(a)
    else:
        strides = a.stride()

    return empty_strided(
        a.shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
    )
# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
def empty_strided(
    shape: Union[ShapeType, Tuple[ShapeType]],
    strides: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.empty_strided`.

    ``dtype`` defaults to ``torch.get_default_dtype()`` and ``device`` to the
    CPU; the tensor itself is created by ``prims.empty_strided``.
    """
    sizes = utils.extract_shape_from_varargs(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")

    return prims.empty_strided(
        sizes, strides, dtype=dtype, device=device, requires_grad=requires_grad
    )
@out_wrapper
def full(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.full`.

    Creates a tensor with ``empty`` and fills it with ``fill_value``.
    """
    return fill(
        empty(shape, dtype=dtype, device=device, requires_grad=requires_grad),
        fill_value,
    )
def full_like(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.full_like`.

    Creates a tensor with ``empty_like`` and fills it with ``fill_value``.
    """
    return fill(
        empty_like(a, dtype=dtype, device=device, requires_grad=requires_grad),
        fill_value,
    )
# Constant-filled constructors: full / full_like with fill_value bound ahead of time.
# NOTE(review): the bound fill values are the booleans True / False rather than 1 / 0;
# presumably fill converts them to the dtype of the tensor being filled -- confirm
ones = partial(full, fill_value=True)

ones_like = partial(full_like, fill_value=True)

zeros = partial(full, fill_value=False)

zeros_like = partial(full_like, fill_value=False)
def uniform(
    shape: ShapeType,
    low: Union[bool, int, float] = 0.0,
    high: Union[bool, int, float] = 1.0,
    *,
    dtype: torch.dtype,
    device: DeviceLikeType,
) -> TensorLikeType:
    """Validates the arguments and forwards them to ``prims.uniform``.

    ``low`` and ``high`` must be bool, int or float and are passed on as
    floats; ``device`` is canonicalized first.
    """
    utils.validate_shape(shape)

    scalar_types = (bool, int, float)
    assert isinstance(low, scalar_types)
    assert isinstance(high, scalar_types)
    lower_bound = float(low)
    upper_bound = float(high)

    assert isinstance(dtype, torch.dtype)
    canonical_device = utils.canonicalize_device(device)

    return prims.uniform(
        shape,
        low=lower_bound,
        high=upper_bound,
        dtype=dtype,
        device=canonical_device,
    )
def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
    """Reference implementation of :func:`torch.masked_fill`.

    Builds the result with ``where(mask, value, a)`` after converting
    ``value`` to the type of ``a``. A tensor ``value`` must be 0-dimensional,
    and a complex ``value`` is only accepted if the type of ``a`` can hold it.
    """
    python_type = utils.dtype_to_type(a.dtype)
    value_is_number = isinstance(value, Number)

    if value_is_number:
        value_type = type(value)
    else:
        # NOTE: Could not use value = item(value) as it resulted in
        # RuntimeError: Cannot cast FakeTensor(cpu) to number
        value_ndim = value.ndim
        check(
            value_ndim == 0,
            lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
        )
        value_type = utils.dtype_to_type(value.dtype)

    # only downcasting from complex to lower type is not allowed.
    # We allow casting `value` to lower type for other case
    # Eg. float -> int.
    # Ref: https://github.com/pytorch/pytorch/issues/79195
    if value_type is complex:
        check(
            utils.is_weakly_lesser_type(value_type, python_type),
            lambda: f"could not convert to type {python_type} without overflow",
        )

    # Since `where` allows type-promotion,
    # cast value to correct type before passing to `where`
    if value_is_number:
        fill_with = python_type(value)
    else:
        assert isinstance(value, TensorLike)
        fill_with = prims.to_dtype(value, a.dtype)
    return where(mask, fill_with, a)
# TODO: add OpInfo for torch.equal and refs.equal
def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
    """Reference implementation of :func:`torch.equal`.

    After checking that both tensors share device and dtype, returns True
    exactly when they have the same shape and all their elements compare
    equal.
    """
    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
    utils.check_same_dtype(a, b)

    # Shape check
    same_shape = a.ndim == b.ndim and tuple(a.shape) == tuple(b.shape)
    if not same_shape:
        return False

    # Short-circuits if there are no elements to validate
    if a.numel() == 0:
        return True

    elementwise_equal = eq(a, b)
    return item(all(elementwise_equal))  # type: ignore[return-value]
# populate the decomp table
import torch._refs.nn.functional
import torch._refs.special