Enabling concat fast path for channels last inputs (#39448)

Summary:
Updates the concat kernel for contiguous inputs to also support channels_last-contiguous tensors.

This was tried with the SqueezeNet model on a Pixel 2 device. It improves model performance by about 25%.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39448

Test Plan: test_cat_in_channels_last

Differential Revision: D22160526

Pulled By: kimishpatel

fbshipit-source-id: 6eee6e74b8a5c66167828283d16a52022a16997f
This commit is contained in:
Kimish Patel 2020-06-23 12:59:10 -07:00 committed by Facebook GitHub Bot
parent 27982d5711
commit 6a421d50ab
3 changed files with 38 additions and 18 deletions

View file

@ -145,6 +145,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
// compute size of the result in the cat dimension
int64_t cat_dim_size = 0;
auto first_tensor_mem_format = tensors[0].suggest_memory_format();
for (int i = 0; i < tensors.size(); i++) {
auto const &tensor = tensors[i];
if (should_skip(tensor)) {
@ -155,7 +156,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i);
cat_dim_size += tensor.size(dim);
if (!tensor.is_contiguous()) {
if (!tensor.is_contiguous(first_tensor_mem_format)) {
allContiguous = false;
}
@ -170,11 +171,14 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
// compute the size of the result
auto result_size = notSkippedTensor.sizes().vec();
result_size[dim] = cat_dim_size;
result.resize_(result_size);
result.resize_(result_size, first_tensor_mem_format);
if (result.numel() == 0) {
return result;
}
// fast path for single thread when both inputs and result are contiguous and not empty
allContiguous = allContiguous && result.is_contiguous(first_tensor_mem_format);
bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
allContiguous = allContiguous && result.is_contiguous();
ScalarType dtype = notSkippedTensor.scalar_type();
if (use_serial_kernel && allContiguous && no_type_promotion && (dtype == ScalarType::Double || dtype == ScalarType::Float)) {
cat_serial_stub(kCPU, result, tensors, dim);
@ -182,7 +186,9 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
}
int64_t offset = 0;
if (reuse_iterator && result.is_contiguous() && no_type_promotion) {
if (reuse_iterator &&
result.is_contiguous(first_tensor_mem_format) &&
no_type_promotion) {
auto source_slice = notSkippedTensor;
auto slice_dim_size = source_slice.size(dim);
auto result_slice = result.narrow(dim, 0, slice_dim_size);

View file

@ -20,27 +20,19 @@ struct InputMeta {
template <typename scalar_t>
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
auto size = result.sizes().vec();
int64_t outer = 1, inner = 1;
for (int64_t i = 0; i < dim; i++) {
outer *= size[i];
}
for (int64_t i = dim + 1; i < size.size(); i++) {
inner *= size[i];
}
int64_t outer = result.numel() / (result.size(dim) * result.stride(dim));
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t ninputs = tensors.size();
std::vector<InputMeta> inputs;
inputs.reserve(ninputs);
for (auto const &tensor : tensors) {
inputs.emplace_back(tensor, dim, inner);
inputs.emplace_back(tensor, dim, result.stride(dim));
}
using Vec = vec256::Vec256<scalar_t>;
int64_t offset = 0;
for (int64_t i = 0; i < outer; i++) {
scalar_t* result_ptr = result_data;
for (int64_t i = 0; i < outer; ++i) {
for (int64_t j = 0; j < ninputs; j++) {
scalar_t* result_ptr = result_data + offset;
int64_t local_inner = inputs[j].inner_size;
scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
if (local_inner < Vec::size()) {
@ -57,7 +49,7 @@ void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
input_ptr,
local_inner);
}
offset += local_inner;
result_ptr += local_inner;
}
}
}

View file

@ -7065,6 +7065,28 @@ class TestTorchDeviceType(TestCase):
res2 = torch.cat((x, y), out=z)
self.assertEqual(res1, res2)
@onlyCPU
def test_cat_in_channels_last(self, device):
    # For every concat dimension of a 4-D tensor, check that concatenating
    # channels_last inputs matches the contiguous result and that the output
    # itself comes back channels_last-contiguous. The second shape is sized
    # larger than the serial-kernel grain size to also exercise the
    # parallel path.
    shapes = ((4, 15, 8, 8), (4, 15, 256, 256))
    for dim in range(4):
        for shape in shapes:
            x = torch.randn(shape, device=device)
            y = torch.randn(x.shape, device=device)
            expected = torch.cat((x, y), dim=dim)
            x = x.clone().contiguous(memory_format=torch.channels_last)
            y = y.clone().contiguous(memory_format=torch.channels_last)
            actual = torch.cat((x, y), dim=dim)
            self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))
            self.assertEqual(expected, actual)
@onlyCUDA
def test_cat_preserve_channels_last(self, device):
x = torch.randn((4, 3, 8, 8), device=device)