mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
And here we go
This commit is contained in:
parent
ec81140cc1
commit
bd631d11ad
4 changed files with 10 additions and 17 deletions
|
|
@ -244,10 +244,11 @@ kernel void upsample_bilinear2d_aa(
|
|||
uint thread_index [[thread_position_in_grid]]) {
|
||||
auto output_x = thread_index % output_sizes.x;
|
||||
auto output_y = thread_index / output_sizes.x;
|
||||
(void)align_corners; // Align corners is unused for AA algorithm
|
||||
auto x_center = area_pixel_compute_source_index(
|
||||
scales.x, output_x, align_corners, /*cubic=*/false);
|
||||
scales.x, output_x, /*align_corners=*/false, /*cubic=*/false);
|
||||
auto y_center = area_pixel_compute_source_index(
|
||||
scales.y, output_y, align_corners, /*cubic=*/false);
|
||||
scales.y, output_y, /*align_corners=*/false, /*cubic=*/false);
|
||||
auto clamped_scales = max(1.0, scales);
|
||||
auto x_min = max(0L, long(floor(x_center - clamped_scales.x + 1)));
|
||||
auto x_max = min(input_sizes.x, long(ceil(x_center + clamped_scales.x)));
|
||||
|
|
@ -271,7 +272,7 @@ kernel void upsample_bilinear2d_aa(
|
|||
outputData
|
||||
[n * output_strides.w + c * output_strides.z +
|
||||
output_x * output_strides.x + output_y * output_strides.y] =
|
||||
res / ws;
|
||||
static_cast<T>(res / ws);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -460,9 +461,11 @@ INSTANTIATE_UPSAMPLE_BILINEAR_AA(float);
|
|||
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
|
||||
INSTANTIATE_UPSAMPLE_BICUBIC(half);
|
||||
INSTANTIATE_UPSAMPLE_BILINEAR(half);
|
||||
INSTANTIATE_UPSAMPLE_BILINEAR_AA(half);
|
||||
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UPSAMPLE_BICUBIC(bfloat);
|
||||
INSTANTIATE_UPSAMPLE_BILINEAR(bfloat);
|
||||
INSTANTIATE_UPSAMPLE_BILINEAR_AA(bfloat);
|
||||
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(bfloat);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -473,16 +473,4 @@ TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_mps)
|
|||
mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d_aa");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_mps)
|
||||
(const Tensor& grad_output,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef input_size,
|
||||
bool align_corners,
|
||||
std::optional<double> scales_h,
|
||||
std::optional<double> scales_w,
|
||||
const Tensor& grad_input) {
|
||||
mps::upsample_kernel_backward_out_template(
|
||||
grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w, "bilinear2d_aa");
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -12738,7 +12738,6 @@
|
|||
dispatch:
|
||||
CPU: _upsample_bilinear2d_aa_backward_out_cpu
|
||||
CUDA: _upsample_bilinear2d_aa_backward_out_cuda
|
||||
MPS: _upsample_bilinear2d_aa_backward_out_mps
|
||||
|
||||
- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
|
||||
python_module: nn
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ def mps_ops_grad_modifier(ops):
|
|||
'__getitem__': [torch.float16],
|
||||
'_segment_reduce': [torch.float16, torch.float32],
|
||||
'_chunk_cat': [torch.float16, torch.float32],
|
||||
'_upsample_bilinear2d_aa': None,
|
||||
'sparse.mmreduce': [torch.float32], # csr not supported
|
||||
'unique_consecutive': [torch.float16, torch.float32],
|
||||
'special_modified_bessel_i0': [torch.float16, torch.float32],
|
||||
|
|
@ -798,7 +799,7 @@ def mps_ops_modifier(ops):
|
|||
'unique': None,
|
||||
'vdot': None,
|
||||
'segment_reduce_': None,
|
||||
'_upsample_bilinear2d_aa': None,
|
||||
'_upsample_bilinear2d_aa': [torch.uint8],
|
||||
'geometric' : None,
|
||||
'geometric_': None,
|
||||
'log_normal_': None,
|
||||
|
|
@ -12523,6 +12524,8 @@ class TestConsistency(TestCaseMPS):
|
|||
return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
|
||||
if op.name == "nn.functional.interpolate":
|
||||
return (1e-3, 1e-4)
|
||||
if op.name == "_upsample_bilinear2d_aa":
|
||||
return (2e-5, 2e-6)
|
||||
if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
|
||||
# TODO: Investigate why this is needed
|
||||
# See https://github.com/pytorch/pytorch/issues/120237
|
||||
|
|
|
|||
Loading…
Reference in a new issue