mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Optimize UpsampleNearest Op (#12151)
Summary: Optimize the UpsampleNearest Op. 1. Add OMP 2. revise the translated_idx method Pull Request resolved: https://github.com/pytorch/pytorch/pull/12151 Differential Revision: D10362856 Pulled By: ezyang fbshipit-source-id: 535a4b87c7423942217f2d79bedc463a0617c67a
This commit is contained in:
parent
ba25e13782
commit
e497aa1e35
1 changed files with 20 additions and 19 deletions
|
|
@ -35,22 +35,6 @@ class UpsampleNearestOp final : public Operator<Context> {
|
|||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
bool RunOnDevice() override {
|
||||
auto translate_idx = [](int ii, int d1, int d2, int d3, int scale_factor) {
|
||||
int x, y, z, w;
|
||||
w = ii % d3;
|
||||
ii = ii/d3;
|
||||
z = ii % d2;
|
||||
ii = ii/d2;
|
||||
y = ii % d1;
|
||||
ii = ii/d1;
|
||||
x = ii;
|
||||
w = w/scale_factor;
|
||||
z = z/scale_factor;
|
||||
d2 /= scale_factor;
|
||||
d3 /= scale_factor;
|
||||
return (((x*d1+y)*d2)+z)*d3+w;
|
||||
};
|
||||
|
||||
auto& X = Input(0);
|
||||
auto* Y = Output(0);
|
||||
auto out_shape = X.dims().vec();
|
||||
|
|
@ -73,11 +57,28 @@ class UpsampleNearestOp final : public Operator<Context> {
|
|||
|
||||
const T *input_data = X.template data<T>();
|
||||
T *output_data = Y->template mutable_data<T>();
|
||||
int scaled_d2 = d2 / scale_;
|
||||
int scaled_d3 = d3 / scale_;
|
||||
|
||||
for (int ii = 0; ii < Y->size(); ii++) {
|
||||
int ipidx = translate_idx(ii, d1, d2, d3, scale_);
|
||||
output_data[ii] = input_data[ipidx];
|
||||
#ifdef _OPENMP
|
||||
#if (_OPENMP >= 201307)
|
||||
#pragma omp parallel for simd
|
||||
#else
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
#endif
|
||||
for (int i = 0; i < d1; ++i) {
|
||||
for (int j = 0; j < d2; ++j) {
|
||||
for (int u = 0; u < d3; ++u) {
|
||||
int ii = (i * d2 + j) * d3 + u;
|
||||
int scaled_u = u / scale_;
|
||||
int scaled_j = j / scale_;
|
||||
int ipidx = ((i * scaled_d2) + scaled_j) * scaled_d3 + scaled_u;
|
||||
output_data[ii] = input_data[ipidx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue