From e497aa1e3535b223d9e7bcdb00a78e0c1b4ce3ec Mon Sep 17 00:00:00 2001 From: wuhuikx Date: Tue, 16 Oct 2018 20:29:06 -0700 Subject: [PATCH] Optimize UpsampleNearest Op (#12151) Summary: Optimize the UpsampleNearest Op. 1. Add OMP 2. revise the translated_idx method Pull Request resolved: https://github.com/pytorch/pytorch/pull/12151 Differential Revision: D10362856 Pulled By: ezyang fbshipit-source-id: 535a4b87c7423942217f2d79bedc463a0617c67a --- modules/detectron/upsample_nearest_op.h | 39 +++++++++++++------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/modules/detectron/upsample_nearest_op.h b/modules/detectron/upsample_nearest_op.h index 7de024f6765..ba5890400a9 100644 --- a/modules/detectron/upsample_nearest_op.h +++ b/modules/detectron/upsample_nearest_op.h @@ -35,22 +35,6 @@ class UpsampleNearestOp final : public Operator { USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { - auto translate_idx = [](int ii, int d1, int d2, int d3, int scale_factor) { - int x, y, z, w; - w = ii % d3; - ii = ii/d3; - z = ii % d2; - ii = ii/d2; - y = ii % d1; - ii = ii/d1; - x = ii; - w = w/scale_factor; - z = z/scale_factor; - d2 /= scale_factor; - d3 /= scale_factor; - return (((x*d1+y)*d2)+z)*d3+w; - }; - auto& X = Input(0); auto* Y = Output(0); auto out_shape = X.dims().vec(); @@ -73,11 +57,28 @@ class UpsampleNearestOp final : public Operator { const T *input_data = X.template data(); T *output_data = Y->template mutable_data(); + int scaled_d2 = d2 / scale_; + int scaled_d3 = d3 / scale_; - for (int ii = 0; ii < Y->size(); ii++) { - int ipidx = translate_idx(ii, d1, d2, d3, scale_); - output_data[ii] = input_data[ipidx]; +#ifdef _OPENMP +#if (_OPENMP >= 201307) +#pragma omp parallel for simd +#else +#pragma omp parallel for +#endif +#endif + for (int i = 0; i < d1; ++i) { + for (int j = 0; j < d2; ++j) { + for (int u = 0; u < d3; ++u) { + int ii = (i * d2 + j) * d3 + u; + int scaled_u = u / scale_; + int scaled_j = j / scale_; + int ipidx = ((i * scaled_d2) + scaled_j) * scaled_d3 + scaled_u; + output_data[ii] = input_data[ipidx]; + } + } } + return true; }