pytorch/caffe2/operators/instance_norm_gradient_op.cc
Xiaomeng Yang 10e4137396 Optimize InstanceNormGradientOp (#22288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22288

Optimize InstanceNormGradientOp

Benchmarks:

CPU with [N, C, H, W] = [128, 256, 56, 56],
NCHW order: 616ms -> 128ms
NHWC order: 1612ms -> 174ms

GPU with [N, C, H, W] = [128, 256, 112, 112],
NCHW order: 6450ms -> 37ms
NHWC order: 1419ms -> 82ms

Reviewed By: houseroad

Differential Revision: D16023630

fbshipit-source-id: 5af9bf1103cde2fc2bcb5cd5a057d039732f052e
2019-07-01 15:10:17 -07:00

268 lines
7.6 KiB
C++

#include "caffe2/operators/instance_norm_op.h"
#include <string>
#include <vector>
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
namespace {
template <typename T>
void ComputeInternalGradientsNHWC(
const int64_t N,
const int64_t C,
const int64_t HxW,
const T* dY,
const T* X,
T* ds,
T* db) {
EigenArrayMap<T> ds_arr(ds, C, N);
EigenArrayMap<T> db_arr(db, C, N);
for (int64_t i = 0; i < N; ++i) {
ConstEigenArrayMap<T> dY_arr(dY + i * C * HxW, C, HxW);
ConstEigenArrayMap<T> X_arr(X + i * C * HxW, C, HxW);
ds_arr.col(i) = dY_arr.col(0) * X_arr.col(0);
db_arr.col(i) = dY_arr.col(0);
for (int j = 1; j < HxW; ++j) {
ds_arr.col(i) += dY_arr.col(j) * X_arr.col(j);
db_arr.col(i) += dY_arr.col(j);
}
}
}
// Backward pass for instance norm on NCHW input. Writes dX and also
// stores the per-(n, c) reductions ds = sum(dY * X) and db = sum(dY)
// so GammaBetaBackward does not have to recompute them.
//
// Forward (per (n, c) over HxW positions): Y = (X - mean) * rstd * gamma
// + beta. The gradient is rearranged into dX = c1 * dY + c2 * X + c3
// with per-instance scalars, so the inner spatial loop is two multiplies
// and two adds per element.
template <typename T>
void InstanceNormBackwardNCHW(
const int64_t N,
const int64_t C,
const int64_t HxW,
const T* dY,
const T* X,
const T* mean,
const T* rstd,
const T* gamma,
T* dX,
T* ds,
T* db) {
// 1 / HxW: normalization factor of the per-instance mean/variance.
const T scale = T(1) / static_cast<T>(HxW);
// Each column i of these HxW x (N*C) views is one (n, c) instance.
ConstEigenArrayMap<T> dY_arr(dY, HxW, N * C);
ConstEigenArrayMap<T> X_arr(X, HxW, N * C);
for (int64_t i = 0; i < N * C; ++i) {
// Spatial reductions for this instance.
const T ds_sum = (dY_arr.col(i) * X_arr.col(i)).sum();
const T db_sum = dY_arr.col(i).sum();
// Channel index: gamma is shared across the batch dimension.
const int64_t c = i % C;
// Coefficients of dX = c1 * dY + c2 * X + c3:
//   c1 = gamma * rstd
//   c2 = gamma * (db_sum * mean - ds_sum) * rstd^3 / HxW
//   c3 = -c2 * mean - gamma * db_sum * rstd / HxW
// (c2 and c3 are built in place, reusing the partial products.)
const T c1 = rstd[i] * gamma[c];
T c2 = ds_sum * gamma[c];
T c3 = db_sum * gamma[c];
c2 = (c3 * mean[i] - c2) * rstd[i] * rstd[i] * rstd[i] * scale;
c3 = -c2 * mean[i] - c3 * rstd[i] * scale;
for (int64_t j = 0; j < HxW; ++j) {
const int64_t index = i * HxW + j;
dX[index] = c1 * dY[index] + c2 * X[index] + c3;
}
// Save the reductions for GammaBetaBackward.
ds[i] = ds_sum;
db[i] = db_sum;
}
}
// Backward pass for instance norm on NHWC input. Consumes the
// precomputed per-(n, c) reductions ds = sum(dY * X) and db = sum(dY)
// (from ComputeInternalGradientsNHWC) and writes dX.
//
// The per-(n, c) coefficients of dX = c1 * dY + c2 * X + c3 are
// materialized into the caller-provided {N, C} scratch buffers c1/c2/c3
// so the spatial sweep below is a fully vectorized Eigen expression.
template <typename T>
void InstanceNormBackwardNHWC(
const int64_t N,
const int64_t C,
const int64_t HxW,
const T* dY,
const T* X,
const T* ds,
const T* db,
const T* mean,
const T* rstd,
const T* gamma,
T* dX,
T* c1,
T* c2,
T* c3) {
// 1 / HxW: normalization factor of the per-instance mean/variance.
const T scale = T(1) / static_cast<T>(HxW);
ConstEigenArrayMap<T> ds_arr(ds, C, N);
ConstEigenArrayMap<T> db_arr(db, C, N);
ConstEigenArrayMap<T> mean_arr(mean, C, N);
ConstEigenArrayMap<T> rstd_arr(rstd, C, N);
ConstEigenVectorArrayMap<T> gamma_arr(gamma, C);
EigenArrayMap<T> c1_arr(c1, C, N);
EigenArrayMap<T> c2_arr(c2, C, N);
EigenArrayMap<T> c3_arr(c3, C, N);
// Build the coefficients in place:
//   c1 = gamma * rstd
//   c2 = gamma * (db * mean - ds) * rstd^3 / HxW
//   c3 = -c2 * mean - gamma * db * rstd / HxW
c1_arr = rstd_arr.colwise() * gamma_arr;
c2_arr = ds_arr.colwise() * gamma_arr;
c3_arr = db_arr.colwise() * gamma_arr;
c2_arr = (c3_arr * mean_arr - c2_arr) * rstd_arr.cube() * scale;
c3_arr = -c2_arr * mean_arr - c3_arr * rstd_arr * scale;
// dX = c1 * dY + c2 * X + c3, broadcasting each column of c1/c2/c3
// across the HxW columns of the per-image C x HxW views.
for (int64_t i = 0; i < N; ++i) {
ConstEigenArrayMap<T> dY_arr(dY + i * HxW * C, C, HxW);
ConstEigenArrayMap<T> X_arr(X + i * HxW * C, C, HxW);
EigenArrayMap<T> dX_arr(dX + i * HxW * C, C, HxW);
dX_arr =
(dY_arr.colwise() * c1_arr.col(i) + X_arr.colwise() * c2_arr.col(i))
.colwise() +
c3_arr.col(i);
}
}
template <typename T>
void GammaBetaBackward(
const int64_t N,
const int64_t C,
const T* ds,
const T* db,
const T* mean,
const T* rstd,
T* dgamma,
T* dbeta) {
ConstEigenArrayMap<T> ds_arr(ds, C, N);
ConstEigenArrayMap<T> db_arr(db, C, N);
ConstEigenArrayMap<T> mean_arr(mean, C, N);
ConstEigenArrayMap<T> rstd_arr(rstd, C, N);
EigenVectorArrayMap<T> dgamma_arr(dgamma, C);
EigenVectorArrayMap<T> dbeta_arr(dbeta, C);
dgamma_arr =
(ds_arr.col(0) - db_arr.col(0) * mean_arr.col(0)) * rstd_arr.col(0);
dbeta_arr = db_arr.col(0);
for (int64_t i = 1; i < N; ++i) {
dgamma_arr +=
(ds_arr.col(i) - db_arr.col(i) * mean_arr.col(i)) * rstd_arr.col(i);
dbeta_arr += db_arr.col(i);
}
}
} // namespace
// Computes the per-(n, c) spatial mean and inverse standard deviation
// rstd = 1 / sqrt(var + epsilon_) of X, for either storage order.
template <>
void InstanceNormGradientOp<float, CPUContext>::ComputeMoments(
const int64_t N,
const int64_t C,
const int64_t HxW,
const float* X,
float* mean,
float* rstd) {
if (order_ == StorageOrder::NCHW) {
// NCHW: each of the N*C rows of length HxW is one instance.
// math::Moments reduces each row to its mean and variance (variance
// is written into rstd), then InvStd converts variance to
// 1 / sqrt(var + epsilon_) in place.
const std::array<int, 2> X_dims = {static_cast<int>(N * C),
static_cast<int>(HxW)};
const std::array<int, 2> Y_dims = {static_cast<int>(N * C), 1};
math::Moments<float, CPUContext>(
2, X_dims.data(), Y_dims.data(), X, mean, rstd, &context_);
math::InvStd<float, CPUContext>(N * C, epsilon_, rstd, rstd, &context_);
} else {
// NHWC: accumulate sum(X) into mean and sum(X^2) into rstd per
// (n, c), one spatial column at a time.
const float c = 1.0f / static_cast<float>(HxW);
EigenArrayMap<float> mean_arr(mean, C, N);
EigenArrayMap<float> rstd_arr(rstd, C, N);
for (int64_t i = 0; i < N; ++i) {
ConstEigenArrayMap<float> X_arr(X + i * HxW * C, C, HxW);
mean_arr.col(i) = X_arr.col(0);
rstd_arr.col(i) = X_arr.col(0).square();
for (int64_t j = 1; j < HxW; ++j) {
mean_arr.col(i) += X_arr.col(j);
rstd_arr.col(i) += X_arr.col(j).square();
}
}
mean_arr *= c;
// var = E[X^2] - E[X]^2, clamped at 0 to guard against tiny
// negative values from floating-point cancellation.
rstd_arr =
((rstd_arr * c - mean_arr.square()).max(0.0f) + epsilon_).rsqrt();
}
}
// NCHW gradient entry point: one fused pass produces dX along with the
// per-(n, c) reductions ds / db, which are then folded into dgamma and
// dbeta. Returns true on success.
template <>
bool InstanceNormGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW(
const int64_t N,
const int64_t C,
const int64_t HxW,
const float* dY,
const float* X,
const float* mean,
const float* rstd,
const float* gamma,
float* dX,
float* dgamma,
float* dbeta) {
  // {N, C} scratch buffers for sum(dY * X) and sum(dY) per instance.
  ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CPU));
  ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CPU));
  float* const ds_ptr = ds_.mutable_data<float>();
  float* const db_ptr = db_.mutable_data<float>();
  InstanceNormBackwardNCHW<float>(
      N, C, HxW, dY, X, mean, rstd, gamma, dX, ds_ptr, db_ptr);
  GammaBetaBackward<float>(N, C, ds_ptr, db_ptr, mean, rstd, dgamma, dbeta);
  return true;
}
// NHWC gradient entry point: first reduce dY * X and dY over the
// spatial dimension, then form dX = c1 * dY + c2 * X + c3 with
// per-(n, c) coefficients, and finally fold the reductions into
// dgamma / dbeta. Returns true on success.
template <>
bool InstanceNormGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC(
const int64_t N,
const int64_t C,
const int64_t HxW,
const float* dY,
const float* X,
const float* mean,
const float* rstd,
const float* gamma,
float* dX,
float* dgamma,
float* dbeta) {
  // Allocate all five {N, C} scratch buffers up front.
  const auto scratch_options = at::dtype<float>().device(CPU);
  ReinitializeTensor(&ds_, {N, C}, scratch_options);
  ReinitializeTensor(&db_, {N, C}, scratch_options);
  ReinitializeTensor(&c1_, {N, C}, scratch_options);
  ReinitializeTensor(&c2_, {N, C}, scratch_options);
  ReinitializeTensor(&c3_, {N, C}, scratch_options);
  float* const ds_ptr = ds_.mutable_data<float>();
  float* const db_ptr = db_.mutable_data<float>();
  float* const c1_ptr = c1_.mutable_data<float>();
  float* const c2_ptr = c2_.mutable_data<float>();
  float* const c3_ptr = c3_.mutable_data<float>();
  ComputeInternalGradientsNHWC<float>(N, C, HxW, dY, X, ds_ptr, db_ptr);
  InstanceNormBackwardNHWC<float>(
      N,
      C,
      HxW,
      dY,
      X,
      ds_ptr,
      db_ptr,
      mean,
      rstd,
      gamma,
      dX,
      c1_ptr,
      c2_ptr,
      c3_ptr);
  GammaBetaBackward<float>(N, C, ds_ptr, db_ptr, mean, rstd, dgamma, dbeta);
  return true;
}
namespace {
// Gradient maker for InstanceNorm. The gradient op always receives
// I(0), I(1), I(2) and dY; when the forward op also produced its
// second/third outputs (presumably the saved mean and rstd — the
// gradient schema allows 4-6 inputs), they are forwarded as well.
class GetInstanceNormGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    std::vector<std::string> grad_inputs{I(0), I(1), I(2), GO(0)};
    // Append O(1) and O(2) only if the forward op emitted them.
    for (int i = 1; i <= 2 && i < def_.output_size(); ++i) {
      grad_inputs.push_back(O(i));
    }
    return SingleGradientDef(
        "InstanceNormGradient",
        "",
        grad_inputs,
        std::vector<std::string>({GI(0), GI(1), GI(2)}));
  }
};
} // namespace
// Register the CPU implementation of the gradient operator.
REGISTER_CPU_OPERATOR(
InstanceNormGradient,
InstanceNormGradientOp<float, CPUContext>);
// 4-6 inputs (the saved forward outputs are optional, matching the
// gradient maker above); 3 outputs: dX, dgamma, dbeta.
OPERATOR_SCHEMA(InstanceNormGradient).NumInputs(4, 6).NumOutputs(3);
// Hook the gradient maker up to the forward InstanceNorm op.
REGISTER_GRADIENT(InstanceNorm, GetInstanceNormGradient);
} // namespace caffe2