diff --git a/caffe2/operators/normalize_op.cc b/caffe2/operators/normalize_op.cc index 5c7d3c76473..ec0600c3d94 100644 --- a/caffe2/operators/normalize_op.cc +++ b/caffe2/operators/normalize_op.cc @@ -1,33 +1,9 @@ #include "caffe2/operators/normalize_op.h" #include "caffe2/core/tensor.h" -#include "caffe2/utils/eigen_utils.h" namespace caffe2 { -template -void NormalizeOp::DoNormalize( - const T* xData, - T* yData, - const int m, - const int n, - const int sf) { - using InnerStride = Eigen::InnerStride; - using StridedVec = - Eigen::Map, 0, InnerStride>; - using ConstStridedVec = - Eigen::Map, 0, InnerStride>; - - for (int i = 0; i < n; ++i) { - auto base = (i / sf) * sf * m + (i % sf); - ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf)); - auto norm = xVec.template lpNorm<2>(); - norm = std::max(norm, kEps_); - StridedVec yVec(yData + base, 1, m, InnerStride(sf)); - yVec = xVec / norm; - } -}; - template void NormalizeGradientOp::DoNormalize( const T* xData, diff --git a/caffe2/operators/normalize_op.h b/caffe2/operators/normalize_op.h index 065484d6759..57c42e6a27d 100644 --- a/caffe2/operators/normalize_op.h +++ b/caffe2/operators/normalize_op.h @@ -3,6 +3,7 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" +#include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" #define KEPS 1e-12f @@ -35,8 +36,27 @@ class NormalizeOp final : public Operator { private: const T kEps_ = KEPS; - void - DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf); + void DoNormalize( + const T* xData, + T* yData, + const int m, + const int n, + const int sf) { + using InnerStride = Eigen::InnerStride; + using StridedVec = + Eigen::Map, 0, InnerStride>; + using ConstStridedVec = + Eigen::Map, 0, InnerStride>; + + for (int i = 0; i < n; ++i) { + auto base = (i / sf) * sf * m + (i % sf); + ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf)); + auto norm = xVec.template lpNorm<2>(); + norm = std::max(norm, kEps_); + StridedVec yVec(yData + base, 1, m, InnerStride(sf)); + yVec = xVec / norm; + } + } }; template