SymInt fixes from symbolic-shapes branch (#86242)
symintify a few inplace meta functions
symintify resize_(), nbytes(), functionalization input mutations
meta funcs for avg_pool2d_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86242
Approved by: https://github.com/Chillee
This commit is contained in:
parent ac25c210e5
commit d07b85393a
7 changed files with 175 additions and 95 deletions
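For context, a meta kernel only computes output metadata (shape, dtype, device) and never touches data. A minimal eager-mode sketch of that contract using the public meta device (background illustration, not code from this commit):

    import torch

    # A meta tensor carries only metadata, so shape propagation runs without
    # allocating or reading any real data.
    a = torch.empty(2, 3, device="meta")
    b = torch.empty(3, 4, device="meta")
    c = a @ b
    print(c.shape, c.device)  # torch.Size([2, 4]) meta

Symbolic shapes push this one step further: the sizes themselves may be SymInts, so meta/shape code must avoid forcing them to concrete integers, which is what the changes below do.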
@@ -294,6 +294,14 @@ class TORCH_API TensorBase {
      return impl_->numel() * impl_->itemsize();
    }

    c10::SymInt sym_nbytes() const {
      TORCH_CHECK(layout () != at::kSparse,
                  "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                  "tensors, add the nbytes of the indices and values. If you want the size of the " \
                  "equivalent dense tensor, multiply numel() by element_size()");
      return impl_->sym_numel() * impl_->itemsize();
    }

    int64_t numel() const {
      return impl_->numel();
    }
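sym_nbytes() mirrors the existing nbytes() but keeps the arithmetic in SymInt, so byte counts stay symbolic when the sizes are symbolic. The eager-mode invariant it preserves, as a quick sketch with plain (non-symbolic) tensors:

    import torch

    # nbytes is numel() * element_size(); sym_nbytes computes the same product
    # with a SymInt numel so symbolic sizes are never forced to concrete ints.
    t = torch.empty(2, 3, dtype=torch.float32)
    assert t.nbytes == t.numel() * t.element_size()  # 6 * 4 == 24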
@@ -2147,7 +2147,7 @@
    QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized
  autogen: _empty_per_channel_affine_quantized.out

- func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
- func: resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
  use_const_ref_for_mutable_tensors: True
  variants: method
  device_check: NoCheck
@@ -733,7 +733,6 @@ symbolic_aot_autograd_failures = {
    xfail('amax', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('amin', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('as_strided', ''), # Tensor-likes are not close!
    xfail('atanh', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
    xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
    xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -748,7 +747,6 @@ symbolic_aot_autograd_failures = {
    xfail('combinations', ''), # aten.masked_select.default
    xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition
    xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
    xfail('copysign', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
    xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition
    xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition
@@ -786,7 +784,6 @@ symbolic_aot_autograd_failures = {
    xfail('fft.rfft2', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('fft.rfft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('fft.rfftn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('fill', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('flatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('fmax', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
    xfail('fmin', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
@@ -859,7 +856,6 @@ symbolic_aot_autograd_failures = {
    xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t...
    xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('masked.normalize', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos...
@@ -870,18 +866,14 @@ symbolic_aot_autograd_failures = {
    xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to...
    xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
    xfail('max', 'binary'), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('max', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
    xfail('max', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('maximum', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('mean', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('median', ''), # could not find kernel
    xfail('meshgrid', 'list_of_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
    xfail('meshgrid', 'variadic_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
    xfail('min', 'binary'), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('min', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
    xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('minimum', ''), # aten.masked_fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('msort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('mv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -915,7 +907,6 @@ symbolic_aot_autograd_failures = {
    xfail('nn.functional.dropout3d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.dropout', ''), # Cannot call numel() on tensor with symbolic sizes/strides
    xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.embedding', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Cannot call numel() on tensor with symbol...
    xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g...
    xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g...
@@ -945,7 +936,6 @@ symbolic_aot_autograd_failures = {
    xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta ...
    xfail('nn.functional.max_unpool3d', ''), # aten.max_unpool3d.default - couldn't find symbolic meta funct...
    xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta ...
    xfail('nn.functional.mish', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.mse_loss', ''), # Unable to cast Python instance to C++ type (#define PYBIND11_DETA...
    xfail('nn.functional.multi_margin_loss', ''), # could not find kernel
    xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel
@@ -963,7 +953,6 @@ symbolic_aot_autograd_failures = {
    xfail('nn.functional.poisson_nll_loss', ''), # aten.add_.Tensor - couldn't find symbolic meta function/d...
    xfail('nn.functional.prelu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
    xfail('nn.functional.silu', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel
    xfail('nn.functional.triplet_margin_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.triplet_margin_with_distance_loss', ''), # Cannot call sizes() on tensor with symbo...
@@ -971,7 +960,6 @@ symbolic_aot_autograd_failures = {
    xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('norm', 'fro'), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
    xfail('normal', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1047,7 +1035,6 @@ symbolic_aot_autograd_failures = {
    xfail('view_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('vsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('vstack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('zero_', ''), # aten.zero_.default - couldn't find symbolic meta function/decomposition
}

def _test_aot_autograd_helper(self, device, dtype, op):
@@ -1099,7 +1099,6 @@ symbolic_tensor_failures = {
    xfail('fft.rfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('fft.rfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('fft.rfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('fill', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
    xfail('unflatten', ''), # RuntimeError: Trying to call aten.size on a tensor with symbolic shapes...
    xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
    xfail('gather', ''), # aten.gather.default - couldn't find symbolic meta function/decomposition
@@ -1113,7 +1112,6 @@ symbolic_tensor_failures = {
    xfail('index_add', ''), # Float
    xfail('index_copy', ''), # Expected a long tensor for index, but got Float
    xfail('index_fill', ''), # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition
    xfail('index_put', ''), # aten.index_put.default - couldn't find symbolic meta function/decomposition
    xfail('index_reduce', ''), # Float
    xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('isclose', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
@@ -1308,7 +1306,6 @@ symbolic_tensor_failures = {
    xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
    xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('zero_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
    xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition
}
symbolic_tensor_segfaults = {
@@ -2443,6 +2443,7 @@ def register_inplace(aten_op, outplace_op):


register_inplace(aten.add_, aten.add)
register_inplace(aten.sub_, aten.sub)
register_inplace(aten.mul_, aten.mul)
register_inplace(aten.relu_, aten.relu)
register_inplace(aten.hardtanh_, aten.hardtanh)
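register_inplace pairs an in-place op with its out-of-place counterpart so the decomposition machinery can rewrite one in terms of the other. A rough sketch of that rewrite as a hypothetical helper (not the in-tree implementation):

    import torch

    # Hypothetical illustration: run the functional op, then copy the result
    # back into self so the in-place aliasing semantics are preserved.
    def make_inplace_decomp(outplace_fn):
        def decomp(self, *args, **kwargs):
            return self.copy_(outplace_fn(self, *args, **kwargs))
        return decomp

    relu_decomp = make_inplace_decomp(torch.relu)
    x = torch.tensor([-1.0, 2.0])
    relu_decomp(x)
    print(x)  # tensor([0., 2.])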
@@ -371,6 +371,133 @@ def meta_conv(
    return out


# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
    check(
        tensor.dim() == dim and tensor.shape[dim_size] == size,
        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
        + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
    )


@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        H = val[0]
        W = H if len(val) == 1 else val[1]
        return H, W

    kH, kW = unpack("kernel_size", kernel_size)
    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    elif len(stride) == 1:
        dH, dW = stride[0], stride[0]
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    if input.dim() == 3:
        size = [nInputPlane, outputHeight, outputWidth]
    else:
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )

# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    nOutputPlane = nInputPlane

    check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
    check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
    check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)


@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    check(
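meta_avg_pool2d and the backward shape check derive output sizes from the standard 2D pooling formula implemented by pooling_output_shape. A sketch of that formula checked against eager avg_pool2d (it ignores the ceil-mode corner case where the last window would start inside the padding, which the in-tree helper also guards):

    import math
    import torch
    import torch.nn.functional as F

    # Standard pooling output size: floor((in + 2*pad - kernel) / stride) + 1,
    # rounding up instead of down when ceil_mode is set.
    def pool_out(in_size, k, pad, stride, ceil_mode):
        num = in_size + 2 * pad - k
        return (math.ceil(num / stride) if ceil_mode else num // stride) + 1

    x = torch.randn(1, 3, 10, 10)
    out = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
    assert out.shape[-1] == pool_out(10, 3, 1, 2, ceil_mode=False)  # 5
    print(out.shape)  # torch.Size([1, 3, 5, 5])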
@@ -770,6 +897,39 @@ def meta_repeat(self, repeats):
    return self.new_empty(target_size)


@register_meta(aten.zero_.default, register_dispatcher=False)
def meta_zero_(self):
    return self


@register_meta(
    [aten.fill.Tensor, aten.fill.Scalar, aten.fill_.Tensor, aten.fill_.Scalar],
    register_dispatcher=False,
)
def meta_fill_(self, val):
    return self


@register_meta(aten.relu_.default, register_dispatcher=False)
def meta_relu_(self):
    return self


@register_meta(aten.index_put.default, register_dispatcher=False)
def meta_index_put(self, indices, values, accumulate=False):
    return self.new_empty(self.size())


@register_meta(aten.masked_fill_.Scalar, register_dispatcher=False)
def meta_masked_fill_(self, mask, value):
    return self


@register_meta(aten.index_put_.default, register_dispatcher=False)
def meta_index_put_(self, indices, values, accumulate=False):
    return self


@register_meta(aten.alias.default, register_dispatcher=False)
def meta_alias(self):
    return self.view(self.shape)
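These registrations encode the metadata contract of in-place ops: shape, dtype, and device are unchanged and the input tensor itself is returned. A small sketch of what that means at the Python level, assuming only the public meta device:

    import torch

    # With no storage behind a meta tensor, an in-place op only has to agree on
    # metadata: the result has the same shape/dtype/device as the input.
    x = torch.empty(2, 3, device="meta")
    y = x.zero_()
    assert y.shape == (2, 3) and y.dtype == x.dtype and y.device.type == "meta"
    z = x.fill_(1.0)
    assert z.shape == x.shape  # no values to inspect; only the metadata matters here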
@@ -1003,80 +1163,6 @@ def meta_max_pool2d_with_indices(
    )


@register_meta(aten.avg_pool2d.default, register_dispatcher=False)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        H = val[0]
        W = H if len(val) == 1 else val[1]
        return H, W

    kH, kW = unpack("kernel_size", kernel_size)
    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    elif len(stride) == 1:
        dH, dW = stride[0], stride[0]
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    if input.dim() == 3:
        size = [nInputPlane, outputHeight, outputWidth]
    else:
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )


@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    return torch.empty(size, *args, **kwargs)
@@ -75,15 +75,16 @@ void _propagate_functional_input_mutation(
  // storage.
  if (unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) {
  } else {
    if (unwrapped.nbytes() != wrapped_inner.nbytes()) {
    if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) {
      // Functions might resize zero-sized inputs, which we need to reflect
      // here.
      unwrapped.resize_(wrapped_inner.sizes());
      unwrapped.resize__symint(wrapped_inner.sym_sizes());
    }
    // If the input tensor's metadata was mutated, then use as_strided_()
    // to propagate the metadata change.
    if (unwrapped.sizes() != wrapped_inner.sizes()) {
      unwrapped.as_strided_(wrapped_inner.sizes(), wrapped_inner.strides());
    if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) {
      unwrapped.as_strided__symint(
          wrapped_inner.sym_sizes(), wrapped_inner.sym_strides());
    }
    unwrapped.copy_(wrapped_inner);
  }
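The functionalization path above handles ops that resize a zero-sized input before writing into it; switching to the symint variants keeps that logic from forcing symbolic sizes to concrete ints. The eager-mode behaviour being mirrored, as a short sketch:

    import torch

    # An op may grow a zero-sized buffer with resize_ and then write values into
    # it; the functionalized input has to pick up both the new size and the new
    # contents from the updated inner tensor.
    buf = torch.empty(0)
    src = torch.arange(6.0).reshape(2, 3)
    buf.resize_(2, 3)   # storage grows from zero elements to match src
    buf.copy_(src)      # values are propagated afterwards
    print(buf.shape)    # torch.Size([2, 3])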