SymInt fixes from symbolic-shapes branch (#86242)

symintify a few inplace meta functions

symintify resize_(), nbytes(), functionalization input mutations

meta funcs for avg_pool2d_backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86242
Approved by: https://github.com/Chillee
This commit is contained in:
Edward Z. Yang 2022-10-04 19:07:32 -07:00 committed by PyTorch MergeBot
parent ac25c210e5
commit d07b85393a
7 changed files with 175 additions and 95 deletions

View file

@ -294,6 +294,14 @@ class TORCH_API TensorBase {
return impl_->numel() * impl_->itemsize();
}
c10::SymInt sym_nbytes() const {
TORCH_CHECK(layout () != at::kSparse,
"nbytes is not defined for sparse tensors. If you want the size of the constituent " \
"tensors, add the nbytes of the indices and values. If you want the size of the " \
"equivalent dense tensor, multiply numel() by element_size()");
return impl_->sym_numel() * impl_->itemsize();
}
int64_t numel() const {
return impl_->numel();
}

View file

@ -2147,7 +2147,7 @@
QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized
autogen: _empty_per_channel_affine_quantized.out
- func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
- func: resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
device_check: NoCheck

View file

@ -733,7 +733,6 @@ symbolic_aot_autograd_failures = {
xfail('amax', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('amin', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('as_strided', ''), # Tensor-likes are not close!
xfail('atanh', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@ -748,7 +747,6 @@ symbolic_aot_autograd_failures = {
xfail('combinations', ''), # aten.masked_select.default
xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition
xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
xfail('copysign', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition
xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition
@ -786,7 +784,6 @@ symbolic_aot_autograd_failures = {
xfail('fft.rfft2', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('fft.rfft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('fft.rfftn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('fill', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('flatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('fmax', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
xfail('fmin', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
@ -859,7 +856,6 @@ symbolic_aot_autograd_failures = {
xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t...
xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.normalize', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos...
@ -870,18 +866,14 @@ symbolic_aot_autograd_failures = {
xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to...
xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
xfail('max', 'binary'), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('max', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
xfail('max', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('maximum', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('mean', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('median', ''), # could not find kernel
xfail('meshgrid', 'list_of_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('meshgrid', 'variadic_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('min', 'binary'), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('min', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('minimum', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('msort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('mv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@ -915,7 +907,6 @@ symbolic_aot_autograd_failures = {
xfail('nn.functional.dropout3d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.dropout', ''), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.embedding', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Cannot call numel() on tensor with symbol...
xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g...
xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g...
@ -945,7 +936,6 @@ symbolic_aot_autograd_failures = {
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta ...
xfail('nn.functional.max_unpool3d', ''), # aten.max_unpool3d.default - couldn't find symbolic meta funct...
xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta ...
xfail('nn.functional.mish', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('nn.functional.mse_loss', ''), # Unable to cast Python instance to C++ type (#define PYBIND11_DETA...
xfail('nn.functional.multi_margin_loss', ''), # could not find kernel
xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel
@ -963,7 +953,6 @@ symbolic_aot_autograd_failures = {
xfail('nn.functional.poisson_nll_loss', ''), # aten.add_.Tensor - couldn't find symbolic meta function/d...
xfail('nn.functional.prelu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
xfail('nn.functional.silu', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel
xfail('nn.functional.triplet_margin_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.triplet_margin_with_distance_loss', ''), # Cannot call sizes() on tensor with symbo...
@ -971,7 +960,6 @@ symbolic_aot_autograd_failures = {
xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('norm', 'fro'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('normal', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides
@ -1047,7 +1035,6 @@ symbolic_aot_autograd_failures = {
xfail('view_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('vsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('vstack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('zero_', ''), # aten.zero_.default - couldn't find symbolic meta function/decomposition
}
def _test_aot_autograd_helper(self, device, dtype, op):

View file

@ -1099,7 +1099,6 @@ symbolic_tensor_failures = {
xfail('fft.rfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.rfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.rfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fill', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('unflatten', ''), # RuntimeError: Trying to call aten.size on a tensor with symbolic shapes...
xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
xfail('gather', ''), # aten.gather.default - couldn't find symbolic meta function/decomposition
@ -1113,7 +1112,6 @@ symbolic_tensor_failures = {
xfail('index_add', ''), # Float
xfail('index_copy', ''), # Expected a long tensor for index, but got Float
xfail('index_fill', ''), # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition
xfail('index_put', ''), # aten.index_put.default - couldn't find symbolic meta function/decomposition
xfail('index_reduce', ''), # Float
xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('isclose', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
@ -1308,7 +1306,6 @@ symbolic_tensor_failures = {
xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('zero_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition
}
symbolic_tensor_segfaults = {

View file

@ -2443,6 +2443,7 @@ def register_inplace(aten_op, outplace_op):
register_inplace(aten.add_, aten.add)
register_inplace(aten.sub_, aten.sub)
register_inplace(aten.mul_, aten.mul)
register_inplace(aten.relu_, aten.relu)
register_inplace(aten.hardtanh_, aten.hardtanh)

View file

@ -371,6 +371,133 @@ def meta_conv(
return out
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
check(
tensor.dim() == dim and tensor.shape[dim_size] == size,
lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
+ f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
)
@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
def meta_avg_pool2d(
input,
kernel_size,
stride=(),
padding=(0,),
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
def unpack(name, val):
check(
len(val) in [1, 2],
lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
)
H = val[0]
W = H if len(val) == 1 else val[1]
return H, W
kH, kW = unpack("kernel_size", kernel_size)
check(
len(stride) in [0, 1, 2],
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
)
if len(stride) == 0:
dH, dW = kH, kW
elif len(stride) == 1:
dH, dW = stride[0], stride[0]
else:
dH, dW = unpack("stride", stride)
padH, padW = unpack("padding", padding)
check(
divisor_override is None or divisor_override != 0,
lambda: "divisor must be not zero",
)
nbatch = input.size(-4) if input.dim() == 4 else 1
nInputPlane = input.size(-3)
inputHeight = input.size(-2)
inputWidth = input.size(-1)
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
memory_format = utils.suggest_memory_format(input)
pool2d_shape_check(
input,
kH,
kW,
dH,
dW,
padH,
padW,
1,
1,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
memory_format,
)
if input.dim() == 3:
size = [nInputPlane, outputHeight, outputWidth]
else:
size = [nbatch, nInputPlane, outputHeight, outputWidth]
return torch.empty(
size, dtype=input.dtype, device=input.device, memory_format=memory_format
)
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
input,
gradOutput,
nbatch,
kH,
kW,
dH,
dW,
padH,
padW,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
mem_format,
):
pool2d_shape_check(
input,
kH,
kW,
dH,
dW,
padH,
padW,
1,
1,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
mem_format,
)
ndim = input.dim()
nOutputPlane = nInputPlane
check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
check(
@ -770,6 +897,39 @@ def meta_repeat(self, repeats):
return self.new_empty(target_size)
@register_meta(aten.zero_.default, register_dispatcher=False)
def meta_zero_(self):
return self
@register_meta(
[aten.fill.Tensor, aten.fill.Scalar, aten.fill_.Tensor, aten.fill_.Scalar],
register_dispatcher=False,
)
def meta_fill_(self, val):
return self
@register_meta(aten.relu_.default, register_dispatcher=False)
def meta_relu_(self):
return self
@register_meta(aten.index_put.default, register_dispatcher=False)
def meta_index_put(self, indices, values, accumulate=False):
return self.new_empty(self.size())
@register_meta(aten.masked_fill_.Scalar, register_dispatcher=False)
def meta_masked_fill_(self, mask, value):
return self
@register_meta(aten.index_put_.default, register_dispatcher=False)
def meta_index_put_(self, indices, values, accumulate=False):
return self
@register_meta(aten.alias.default, register_dispatcher=False)
def meta_alias(self):
return self.view(self.shape)
@ -1003,80 +1163,6 @@ def meta_max_pool2d_with_indices(
)
@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
def meta_avg_pool2d(
input,
kernel_size,
stride=(),
padding=(0,),
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
def unpack(name, val):
check(
len(val) in [1, 2],
lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
)
H = val[0]
W = H if len(val) == 1 else val[1]
return H, W
kH, kW = unpack("kernel_size", kernel_size)
check(
len(stride) in [0, 1, 2],
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
)
if len(stride) == 0:
dH, dW = kH, kW
elif len(stride) == 1:
dH, dW = stride[0], stride[0]
else:
dH, dW = unpack("stride", stride)
padH, padW = unpack("padding", padding)
check(
divisor_override is None or divisor_override != 0,
lambda: "divisor must be not zero",
)
nbatch = input.size(-4) if input.dim() == 4 else 1
nInputPlane = input.size(-3)
inputHeight = input.size(-2)
inputWidth = input.size(-1)
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
memory_format = utils.suggest_memory_format(input)
pool2d_shape_check(
input,
kH,
kW,
dH,
dW,
padH,
padW,
1,
1,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
memory_format,
)
if input.dim() == 3:
size = [nInputPlane, outputHeight, outputWidth]
else:
size = [nbatch, nInputPlane, outputHeight, outputWidth]
return torch.empty(
size, dtype=input.dtype, device=input.device, memory_format=memory_format
)
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
return torch.empty(size, *args, **kwargs)

View file

@ -75,15 +75,16 @@ void _propagate_functional_input_mutation(
// storage.
if (unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) {
} else {
if (unwrapped.nbytes() != wrapped_inner.nbytes()) {
if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) {
// Functions might resize zero-sized inputs, which we need to reflect
// ehre.
unwrapped.resize_(wrapped_inner.sizes());
unwrapped.resize__symint(wrapped_inner.sym_sizes());
}
// If the input tensor's metadata was mutated, then use as_strided_()
// to propagate the metadata change.
if (unwrapped.sizes() != wrapped_inner.sizes()) {
unwrapped.as_strided_(wrapped_inner.sizes(), wrapped_inner.strides());
if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) {
unwrapped.as_strided__symint(
wrapped_inner.sym_sizes(), wrapped_inner.sym_strides());
}
unwrapped.copy_(wrapped_inner);
}