Enabling concat fast path for channels last inputs (#39448)
Summary: Updates the concat kernel's contiguous-input fast path to also accept channels_last-contiguous tensors. This was tried with the squeezenet model on a Pixel 2 device and improves model performance by about 25%.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39448

Test Plan: test_cat_in_channels_last

Differential Revision: D22160526

Pulled By: kimishpatel

fbshipit-source-id: 6eee6e74b8a5c66167828283d16a52022a16997f
parent 27982d5711
commit 6a421d50ab

3 changed files with 38 additions and 18 deletions
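
For context, this is the usage pattern the fast path now covers: concatenating 4-D tensors that are contiguous in the channels_last memory format. A minimal Python sketch (shapes are illustrative, not taken from the PR):

import torch

# Two NCHW tensors stored in channels_last (NHWC) order.
x = torch.randn(4, 15, 8, 8).contiguous(memory_format=torch.channels_last)
y = torch.randn(4, 15, 8, 8).contiguous(memory_format=torch.channels_last)

# With this change the single-threaded CPU cat kernel handles these inputs
# directly, and the result keeps the channels_last layout.
out = torch.cat((x, y), dim=1)
print(out.is_contiguous(memory_format=torch.channels_last))  # True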
@@ -145,6 +145,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   // compute size of the result in the cat dimension
   int64_t cat_dim_size = 0;
+  auto first_tensor_mem_format = tensors[0].suggest_memory_format();
   for (int i = 0; i < tensors.size(); i++) {
     auto const &tensor = tensors[i];
     if (should_skip(tensor)) {
@@ -155,7 +156,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
     check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i);
     cat_dim_size += tensor.size(dim);

-    if (!tensor.is_contiguous()) {
+    if (!tensor.is_contiguous(first_tensor_mem_format)) {
       allContiguous = false;
     }
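
The check being relaxed above is the difference between default contiguity and contiguity with respect to a memory format: a channels_last tensor fails plain is_contiguous() but passes the check for its own format. A quick illustration in Python:

import torch

t = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
print(t.is_contiguous())                                       # False
print(t.is_contiguous(memory_format=torch.channels_last))      # True
print(t.is_contiguous(memory_format=torch.contiguous_format))  # False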
@@ -170,11 +171,14 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   // compute the size of the result
   auto result_size = notSkippedTensor.sizes().vec();
   result_size[dim] = cat_dim_size;
-  result.resize_(result_size);
+  result.resize_(result_size, first_tensor_mem_format);
   if (result.numel() == 0) {
     return result;
   }

   // fast path for single thread when both inputs and result are contiguous and not empty
+  allContiguous = allContiguous && result.is_contiguous(first_tensor_mem_format);
   bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
-  allContiguous = allContiguous && result.is_contiguous();
   ScalarType dtype = notSkippedTensor.scalar_type();
   if (use_serial_kernel && allContiguous && no_type_promotion && (dtype == ScalarType::Double || dtype == ScalarType::Float)) {
     cat_serial_stub(kCPU, result, tensors, dim);
@@ -182,7 +186,9 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   }

   int64_t offset = 0;
-  if (reuse_iterator && result.is_contiguous() && no_type_promotion) {
+  if (reuse_iterator &&
+      result.is_contiguous(first_tensor_mem_format) &&
+      no_type_promotion) {
     auto source_slice = notSkippedTensor;
     auto slice_dim_size = source_slice.size(dim);
     auto result_slice = result.narrow(dim, 0, slice_dim_size);
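
Taken together, the changes to _cat_out_cpu gate the fast path on a single memory format: the first input's suggested format must be the format in which every input and the resized output are contiguous. A rough Python sketch of that gating logic (suggested_format and cat_fast_path_ok are illustrative helpers, not PyTorch APIs):

import torch

def suggested_format(t):
    # Rough analogue of ATen's suggest_memory_format(): prefer channels_last
    # when the strides already match that layout.
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    return torch.contiguous_format

def cat_fast_path_ok(tensors, result):
    # Every input and the output must be contiguous w.r.t. the first input's format.
    fmt = suggested_format(tensors[0])
    return (all(t.is_contiguous(memory_format=fmt) for t in tensors)
            and result.is_contiguous(memory_format=fmt))

a = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
b = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
out = torch.empty(4, 3, 4, 4).contiguous(memory_format=torch.channels_last)
print(cat_fast_path_ok([a, b], out))  # True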
@@ -20,27 +20,19 @@ struct InputMeta {

 template <typename scalar_t>
 void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
-  auto size = result.sizes().vec();
-  int64_t outer = 1, inner = 1;
-  for (int64_t i = 0; i < dim; i++) {
-    outer *= size[i];
-  }
-  for (int64_t i = dim + 1; i < size.size(); i++) {
-    inner *= size[i];
-  }
+  int64_t outer = result.numel() / (result.size(dim) * result.stride(dim));
   scalar_t* result_data = result.data_ptr<scalar_t>();
   int64_t ninputs = tensors.size();
   std::vector<InputMeta> inputs;
   inputs.reserve(ninputs);
   for (auto const &tensor : tensors) {
-    inputs.emplace_back(tensor, dim, inner);
+    inputs.emplace_back(tensor, dim, result.stride(dim));
   }

   using Vec = vec256::Vec256<scalar_t>;
-  int64_t offset = 0;
-  for (int64_t i = 0; i < outer; i++) {
+  scalar_t* result_ptr = result_data;
+  for (int64_t i = 0; i < outer; ++i) {
     for (int64_t j = 0; j < ninputs; j++) {
-      scalar_t* result_ptr = result_data + offset;
       int64_t local_inner = inputs[j].inner_size;
       scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
       if (local_inner < Vec::size()) {
@@ -57,7 +49,7 @@ void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
             input_ptr,
             local_inner);
       }
-      offset += local_inner;
+      result_ptr += local_inner;
     }
   }
 }
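
The kernel rewrite works because, for any dense layout, the concatenation dimension occupies a contiguous block of result.size(dim) * result.stride(dim) elements in memory, repeated outer = result.numel() / (result.size(dim) * result.stride(dim)) times; each input contributes its own run of size(dim) * stride(dim) elements per outer step. A slow, pure-Python rendering of that walk (physical_flat and serial_cat_sketch are illustrative names; it assumes all inputs share the same dense layout and dtype):

import torch

def physical_flat(t):
    # Elements of a dense tensor in memory order: permute dims by descending
    # stride, then flatten.
    order = sorted(range(t.dim()), key=t.stride, reverse=True)
    return t.permute(order).contiguous().view(-1)

def serial_cat_sketch(tensors, dim):
    # Mimics cat_serial_kernel_impl's iteration (sketch only, not the real kernel).
    ref = tensors[0]
    inner = ref.stride(dim)                       # elements per step along dim
    outer = ref.numel() // (ref.size(dim) * inner)
    flats = [physical_flat(t) for t in tensors]
    chunks = []
    for i in range(outer):                        # one pass per "outer" slice
        for t, flat in zip(tensors, flats):
            local = t.size(dim) * inner           # this input's run per outer step
            chunks.append(flat[i * local:(i + 1) * local])
    return torch.cat(chunks)                      # result in physical (memory) order

# channels_last example: the flat buffer comes out in NHWC order.
x = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
y = torch.randn(2, 5, 4, 4).contiguous(memory_format=torch.channels_last)
flat = serial_cat_sketch([x, y], dim=1)
rebuilt = flat.view(2, 4, 4, 8).permute(0, 3, 1, 2)  # NHWC buffer -> NCHW view
assert torch.equal(rebuilt, torch.cat((x, y), dim=1))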
@@ -7065,6 +7065,28 @@ class TestTorchDeviceType(TestCase):
         res2 = torch.cat((x, y), out=z)
         self.assertEqual(res1, res2)

+    @onlyCPU
+    def test_cat_in_channels_last(self, device):
+        for dim in range(4):
+            x = torch.randn((4, 15, 8, 8), device=device)
+            y = torch.randn(x.shape, device=device)
+            res1 = torch.cat((x, y), dim=dim)
+            x = x.clone().contiguous(memory_format=torch.channels_last)
+            y = y.clone().contiguous(memory_format=torch.channels_last)
+            res2 = torch.cat((x, y), dim=dim)
+            self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
+            self.assertEqual(res1, res2)
+
+            # Size larger than grain size.
+            x = torch.randn((4, 15, 256, 256), device=device)
+            y = torch.randn(x.shape, device=device)
+            res1 = torch.cat((x, y), dim=dim)
+            x = x.clone().contiguous(memory_format=torch.channels_last)
+            y = y.clone().contiguous(memory_format=torch.channels_last)
+            res2 = torch.cat((x, y), dim=dim)
+            self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
+            self.assertEqual(res1, res2)
+
     @onlyCUDA
     def test_cat_preserve_channels_last(self, device):
         x = torch.randn((4, 3, 8, 8), device=device)
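
For anyone who wants to sanity-check the claim locally, a rough single-thread timing harness along these lines can be used (shapes, iteration count, and the bench_cat helper are arbitrary choices; the ~25% figure quoted in the summary came from squeezenet on a Pixel 2, not from this snippet):

import time
import torch

def bench_cat(x, y, iters=1000):
    # Warm up, then time repeated concatenations along the channel dimension.
    for _ in range(10):
        torch.cat((x, y), dim=1)
    start = time.perf_counter()
    for _ in range(iters):
        torch.cat((x, y), dim=1)
    return (time.perf_counter() - start) / iters

torch.set_num_threads(1)  # the serial fast path is the single-thread kernel

x = torch.randn(1, 64, 56, 56)
y = torch.randn(1, 64, 56, 56)
print("contiguous:    %.1f us" % (bench_cat(x, y) * 1e6))

xc = x.contiguous(memory_format=torch.channels_last)
yc = y.contiguous(memory_format=torch.channels_last)
print("channels_last: %.1f us" % (bench_cat(xc, yc) * 1e6))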