diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h
index 1ba5d71b73c..8ec5670664a 100644
--- a/aten/src/ATen/core/TensorBase.h
+++ b/aten/src/ATen/core/TensorBase.h
@@ -294,6 +294,16 @@ class TORCH_API TensorBase {
     return impl_->numel() * impl_->itemsize();
   }

+  // Returns the tensor's size in bytes as a c10::SymInt, computed as
+  // sym_numel() * itemsize(). Fails a TORCH_CHECK for sparse-layout tensors.
+  c10::SymInt sym_nbytes() const {
+    TORCH_CHECK(layout() != at::kSparse,
+                "nbytes is not defined for sparse tensors. If you want the size of the constituent "
+                "tensors, add the nbytes of the indices and values. If you want the size of the "
+                "equivalent dense tensor, multiply numel() by element_size()");
+    return impl_->sym_numel() * impl_->itemsize();
+  }
+
   int64_t numel() const {
     return impl_->numel();
   }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 50fd5f1aed6..cff12cae337 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2147,7 +2147,7 @@
     QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized
   autogen: _empty_per_channel_affine_quantized.out

-- func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+- func: resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
   variants: method
   device_check: NoCheck
diff --git a/functorch/test/test_aotdispatch.py b/functorch/test/test_aotdispatch.py
index 48db55f35d8..0c45a5ef6d9 100644
--- a/functorch/test/test_aotdispatch.py
+++ b/functorch/test/test_aotdispatch.py
@@ -733,7 +733,6 @@ symbolic_aot_autograd_failures = {
     xfail('amax', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('amin', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('as_strided', ''),  # Tensor-likes are not close!
- xfail('atanh', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -748,7 +747,6 @@ symbolic_aot_autograd_failures = { xfail('combinations', ''), # aten.masked_select.default xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition - xfail('copysign', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition @@ -786,7 +784,6 @@ symbolic_aot_autograd_failures = { xfail('fft.rfft2', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.rfft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.rfftn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('fill', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('flatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fmax', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition xfail('fmin', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition @@ -859,7 +856,6 @@ symbolic_aot_autograd_failures = { xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t... 
xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.normalize', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... @@ -870,18 +866,14 @@ symbolic_aot_autograd_failures = { xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... - xfail('max', 'binary'), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('max', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec... xfail('max', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('maximum', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('mean', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('median', ''), # could not find kernel xfail('meshgrid', 'list_of_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides xfail('meshgrid', 'variadic_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides - xfail('min', 'binary'), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('min', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec... 
xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('minimum', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('msort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -915,7 +907,6 @@ symbolic_aot_autograd_failures = { xfail('nn.functional.dropout3d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.dropout', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('nn.functional.embedding', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Cannot call numel() on tensor with symbol... xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g... xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g... @@ -945,7 +936,6 @@ symbolic_aot_autograd_failures = { xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta ... xfail('nn.functional.max_unpool3d', ''), # aten.max_unpool3d.default - couldn't find symbolic meta funct... xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta ... - xfail('nn.functional.mish', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('nn.functional.mse_loss', ''), # Unable to cast Python instance to C++ type (#define PYBIND11_DETA... 
xfail('nn.functional.multi_margin_loss', ''), # could not find kernel xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel @@ -963,7 +953,6 @@ symbolic_aot_autograd_failures = { xfail('nn.functional.poisson_nll_loss', ''), # aten.add_.Tensor - couldn't find symbolic meta function/d... xfail('nn.functional.prelu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function... - xfail('nn.functional.silu', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel xfail('nn.functional.triplet_margin_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.triplet_margin_with_distance_loss', ''), # Cannot call sizes() on tensor with symbo... @@ -971,7 +960,6 @@ symbolic_aot_autograd_failures = { xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('norm', 'fro'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition xfail('normal', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -1047,7 +1035,6 @@ symbolic_aot_autograd_failures = { xfail('view_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('vsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('vstack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('zero_', ''), # aten.zero_.default - couldn't find symbolic meta 
function/decomposition } def _test_aot_autograd_helper(self, device, dtype, op): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 11c82f30fb9..67bcfe7864b 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1099,7 +1099,6 @@ symbolic_tensor_failures = { xfail('fft.rfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.rfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.rfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('fill', ''), # The underlying op of 'aten.stride' has no overload name '_schema' xfail('unflatten', ''), # RuntimeError: Trying to call aten.size on a tensor with symbolic shapes... xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition xfail('gather', ''), # aten.gather.default - couldn't find symbolic meta function/decomposition @@ -1113,7 +1112,6 @@ symbolic_tensor_failures = { xfail('index_add', ''), # Float xfail('index_copy', ''), # Expected a long tensor for index, but got Float xfail('index_fill', ''), # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition - xfail('index_put', ''), # aten.index_put.default - couldn't find symbolic meta function/decomposition xfail('index_reduce', ''), # Float xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('isclose', ''), # The underlying op of 'aten.stride' has no overload name '_schema' @@ -1308,7 +1306,6 @@ symbolic_tensor_failures = { xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('zero_', ''), # aten.clone.default - couldn't find symbolic meta 
function/decomposition
     xfail('unbind', ''),  # aten.unbind.int - couldn't find symbolic meta function/decomposition
 }
 symbolic_tensor_segfaults = {
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 3c814fe0f03..3095acb4391 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -2443,6 +2443,7 @@ def register_inplace(aten_op, outplace_op):


 register_inplace(aten.add_, aten.add)
+register_inplace(aten.sub_, aten.sub)
 register_inplace(aten.mul_, aten.mul)
 register_inplace(aten.relu_, aten.relu)
 register_inplace(aten.hardtanh_, aten.hardtanh)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 701888bb4fa..c0e86a03eeb 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -371,6 +371,133 @@ def meta_conv(
     return out


+# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
+def check_dim_size(tensor, dim, dim_size, size):
+    check(
+        tensor.dim() == dim and tensor.shape[dim_size] == size,
+        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
+        + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
+    )
+
+
+@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
+def meta_avg_pool2d(
+    input,
+    kernel_size,
+    stride=(),
+    padding=(0,),
+    ceil_mode=False,
+    count_include_pad=True,
+    divisor_override=None,
+):
+    def unpack(name, val):
+        check(
+            len(val) in [1, 2],
+            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
+        )
+        H = val[0]
+        W = H if len(val) == 1 else val[1]
+        return H, W
+
+    kH, kW = unpack("kernel_size", kernel_size)
+    check(
+        len(stride) in [0, 1, 2],
+        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
+    )
+    if len(stride) == 0:
+        dH, dW = kH, kW
+    elif len(stride) == 1:
+        dH, dW = stride[0], stride[0]
+    else:
+        dH, dW = unpack("stride", stride)
+
+    padH, padW = unpack("padding", padding)
+
+    check(
+        divisor_override is None or divisor_override != 0,
+        lambda: "divisor must be not zero",
+    )
+
+    nbatch = input.size(-4) if input.dim() == 4 else 1
+    nInputPlane = input.size(-3)
+    inputHeight = input.size(-2)
+    inputWidth = input.size(-1)
+
+    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
+    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
+
+    memory_format = utils.suggest_memory_format(input)
+    pool2d_shape_check(
+        input,
+        kH,
+        kW,
+        dH,
+        dW,
+        padH,
+        padW,
+        1,
+        1,
+        nInputPlane,
+        inputHeight,
+        inputWidth,
+        outputHeight,
+        outputWidth,
+        memory_format,
+    )
+
+    if input.dim() == 3:
+        size = [nInputPlane, outputHeight, outputWidth]
+    else:
+        size = [nbatch, nInputPlane, outputHeight, outputWidth]
+    return torch.empty(
+        size, dtype=input.dtype, device=input.device, memory_format=memory_format
+    )
+
+
+# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
+def avg_pool2d_backward_shape_check(
+    input,
+    gradOutput,
+    nbatch,
+    kH,
+    kW,
+    dH,
+    dW,
+    padH,
+    padW,
+    nInputPlane,
+    inputHeight,
+    inputWidth,
+    outputHeight,
+    outputWidth,
+    mem_format,
+):
+    pool2d_shape_check(
+        input,
+        kH,
+        kW,
+        dH,
+        dW,
+        padH,
+        padW,
+        1,
+        1,
+        nInputPlane,
+        inputHeight,
+        inputWidth,
+        outputHeight,
+        outputWidth,
+        mem_format,
+    )
+
+    ndim = input.dim()
+    nOutputPlane = nInputPlane
+
+    check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
+    check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
+    check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
+
+
 @register_meta(aten._adaptive_avg_pool2d.default)
 def meta_adaptive_avg_pool2d(self, output_size):
     check(
@@ -770,6 +897,46 @@ def meta_repeat(self, repeats):
     return self.new_empty(target_size)


+# Meta kernels for the in-place (trailing-underscore) ops in this group simply
+# return `self`.
+@register_meta(aten.zero_.default, register_dispatcher=False)
+def meta_zero_(self):
+    return self
+
+
+@register_meta([aten.fill_.Tensor, aten.fill_.Scalar], register_dispatcher=False)
+def meta_fill_(self, val):
+    return self
+
+
+# Out-of-place fill returns a new tensor; returning `self` here would make the
+# result alias the input.
+@register_meta([aten.fill.Tensor, aten.fill.Scalar], register_dispatcher=False)
+def meta_fill(self, val):
+    return torch.empty_like(self)
+
+
+@register_meta(aten.relu_.default, register_dispatcher=False)
+def meta_relu_(self):
+    return self
+
+
+# Out-of-place index_put: a fresh tensor with the same sizes as `self`.
+@register_meta(aten.index_put.default, register_dispatcher=False)
+def meta_index_put(self, indices, values, accumulate=False):
+    return self.new_empty(self.size())
+
+
+@register_meta(aten.masked_fill_.Scalar, register_dispatcher=False)
+def meta_masked_fill_(self, mask, value):
+    return self
+
+
+@register_meta(aten.index_put_.default, register_dispatcher=False)
+def meta_index_put_(self, indices, values, accumulate=False):
+    return self
+
+
 @register_meta(aten.alias.default, register_dispatcher=False)
 def meta_alias(self):
     return self.view(self.shape)
@@ -1003,80 +1170,6 @@ def meta_max_pool2d_with_indices(
     )


-@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
-def meta_avg_pool2d(
-    input,
-    kernel_size,
-    stride=(),
-    padding=(0,),
-    ceil_mode=False,
-    count_include_pad=True,
-    divisor_override=None,
-):
-    def unpack(name, val):
-        check(
-            len(val) in [1, 2],
-            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
-        )
-        H = val[0]
-        W = H if len(val) == 1 else val[1]
-        return H, W
-
-    kH, kW = unpack("kernel_size", kernel_size)
-    check(
-        len(stride) in [0, 1, 2],
-        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
-    )
-    if len(stride) == 0:
-        dH, dW = kH, kW
-    elif len(stride) == 1:
-        dH, dW = stride[0], stride[0]
-    else:
-        dH, dW = unpack("stride", stride)
-
-    padH, padW = unpack("padding", padding)
-
-    check(
-        divisor_override is None or divisor_override != 0,
-        lambda: "divisor must be not zero",
-    )
-
-    nbatch = input.size(-4) if input.dim() == 4 else 1
-    nInputPlane = input.size(-3)
-    inputHeight = input.size(-2)
-    inputWidth = input.size(-1)
-
-    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
-    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
-
-    memory_format = utils.suggest_memory_format(input)
-    pool2d_shape_check(
-        input,
-        kH,
-        kW,
-        dH,
-        dW,
-        padH,
-        padW,
-        1,
-        1,
-        nInputPlane,
-        inputHeight,
-        inputWidth,
-        outputHeight,
-        outputWidth,
-        memory_format,
-    )
-
-    if input.dim() == 3:
-        size = [nInputPlane, outputHeight, outputWidth]
-    else:
-        size = [nbatch, nInputPlane, outputHeight, outputWidth]
-    return torch.empty(
-        size, dtype=input.dtype, device=input.device, memory_format=memory_format
-    )
-
-
 @register_meta([aten.full.default])
 def full(size, fill_value, *args, **kwargs):
     return torch.empty(size, *args, **kwargs)
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index fe7116c1d61..b1f696ee3c7 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -75,15 +75,16 @@ void _propagate_functional_input_mutation(
   // storage.
   if (unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) {
   } else {
-    if (unwrapped.nbytes() != wrapped_inner.nbytes()) {
+    if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) {
       // Functions might resize zero-sized inputs, which we need to reflect
-      // ehre.
-      unwrapped.resize_(wrapped_inner.sizes());
+      // here.
+      unwrapped.resize__symint(wrapped_inner.sym_sizes());
     }
     // If the input tensor's metadata was mutated, then use as_strided_()
     // to propagate the metadata change.
-    if (unwrapped.sizes() != wrapped_inner.sizes()) {
-      unwrapped.as_strided_(wrapped_inner.sizes(), wrapped_inner.strides());
+    if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) {
+      unwrapped.as_strided__symint(
+          wrapped_inner.sym_sizes(), wrapped_inner.sym_strides());
     }
     unwrapped.copy_(wrapped_inner);
   }