[BE] Use squeeze/unsqueeze in im2col (#136006)

And move unsqeeze out of the dispatch, as it's dtype agnostic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136006
Approved by: https://github.com/Skylion007, https://github.com/eqy
This commit is contained in:
Nikita Shulga 2024-09-14 00:35:37 +00:00 committed by PyTorch MergeBot
parent 4237592b8f
commit 081c4a966d
2 changed files with 8 additions and 8 deletions

View file

@ -91,7 +91,7 @@ void col2im_out_cuda_template(
if (input.dim() == 2) {
// Force batch
batched_input = false;
input = input.view({1, input.size(0), input.size(1)});
input = input.unsqueeze(0);
}
int64_t batch_size = input.size(0);
@ -134,10 +134,10 @@ void col2im_out_cuda_template(
output.mutable_data_ptr<scalar_t>(),
output_batch_stride);
if (!batched_input) {
output.resize_({n_output_plane, output_height, output_width});
}
});
if (!batched_input) {
output = output.squeeze(0);
}
}
} // namespace

View file

@ -81,7 +81,7 @@ static void im2col_out_cuda_template(
if (input.dim() == 3) {
batched_input = false;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
input = input.unsqueeze(0);
}
int64_t batch_size = input.size(0);
@ -131,10 +131,10 @@ static void im2col_out_cuda_template(
output_n.mutable_data_ptr<scalar_t>());
}
if (!batched_input) {
output.resize_({n_output_plane, output_length});
}
});
if (!batched_input) {
output = output.squeeze(0);
}
}
} // namespace