mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE] Use squeeze/unsqueeze in im2col (#136006)
And move unsqeeze out of the dispatch, as it's dtype agnostic Pull Request resolved: https://github.com/pytorch/pytorch/pull/136006 Approved by: https://github.com/Skylion007, https://github.com/eqy
This commit is contained in:
parent
4237592b8f
commit
081c4a966d
2 changed files with 8 additions and 8 deletions
|
|
@ -91,7 +91,7 @@ void col2im_out_cuda_template(
|
|||
if (input.dim() == 2) {
|
||||
// Force batch
|
||||
batched_input = false;
|
||||
input = input.view({1, input.size(0), input.size(1)});
|
||||
input = input.unsqueeze(0);
|
||||
}
|
||||
|
||||
int64_t batch_size = input.size(0);
|
||||
|
|
@ -134,10 +134,10 @@ void col2im_out_cuda_template(
|
|||
output.mutable_data_ptr<scalar_t>(),
|
||||
output_batch_stride);
|
||||
|
||||
if (!batched_input) {
|
||||
output.resize_({n_output_plane, output_height, output_width});
|
||||
}
|
||||
});
|
||||
if (!batched_input) {
|
||||
output = output.squeeze(0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ static void im2col_out_cuda_template(
|
|||
|
||||
if (input.dim() == 3) {
|
||||
batched_input = false;
|
||||
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
||||
input = input.unsqueeze(0);
|
||||
}
|
||||
|
||||
int64_t batch_size = input.size(0);
|
||||
|
|
@ -131,10 +131,10 @@ static void im2col_out_cuda_template(
|
|||
output_n.mutable_data_ptr<scalar_t>());
|
||||
}
|
||||
|
||||
if (!batched_input) {
|
||||
output.resize_({n_output_plane, output_length});
|
||||
}
|
||||
});
|
||||
if (!batched_input) {
|
||||
output = output.squeeze(0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
Loading…
Reference in a new issue