diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
index cfc3685fa00..73c77554cee 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.cc
@@ -1,5 +1,6 @@
 #include "layernorm_fp16_fake_op.h"
 #include "caffe2/contrib/fakelowp/common.h"
+#include "caffe2/contrib/fakelowp/fp16_fma.h"
 
 namespace caffe2 {
 
@@ -20,33 +21,30 @@ void LayerNormFakeFp16Op::calcY(
   EigenArrayMap<T> Y_arr(Y, N, M);
   T tmp = T(0);
 
+  for (int i = 0; i < M; ++i) {
+    T normFactor = T(T(1) / std_arr[i]);
+    fp16_wrap(&normFactor);
+    for (int j = 0; j < N; ++j) {
+      T normalized = T(X_arr.col(i)[j] - mean[i]);
+      fp16_wrap(&normalized);
+      normalized *= normFactor;
+      fp16_wrap(&normalized);
+      Y_arr.col(i)[j] = normalized;
+    }
+  }
+
   if (gamma != nullptr && beta != nullptr) {
     ConstEigenVectorArrayMap<T> gamma_arr(gamma, N);
     ConstEigenVectorArrayMap<T> beta_arr(beta, N);
     for (int i = 0; i < M; ++i) {
-      T normFactor = T(T(1) / std_arr[i]);
-      fp16_wrap(&normFactor);
-      for (int j = 0; j < N; ++j) {
-        tmp = T(X_arr.col(i)[j] - mean[i]);
-        fp16_wrap(&tmp);
-        T normalized = tmp * normFactor;
-        fp16_wrap(&normalized);
-        tmp = normalized * gamma_arr[j] + beta_arr[j];
-        fp16_wrap(&tmp);
-        Y_arr.col(i)[j] = tmp;
+      vector<float> res(N);
+      for (int j = 0; j < N; j++) {
+        res[j] = beta[j];
       }
-    }
-  } else {
-    for (int i = 0; i < M; ++i) {
-      T normFactor = T(T(1) / std_arr[i]);
-      fp16_wrap(&normFactor);
-      for (int j = 0; j < N; ++j) {
-        tmp = T(X_arr.col(i)[j] - mean[i]);
-        fp16_wrap(&tmp);
-        tmp *= normFactor;
-        fp16_wrap(&tmp);
-        Y_arr.col(i)[j] = tmp;
+      fake_fp16::fma_fp16(N, &Y_arr.col(i)[0], gamma, res.data());
+      for (int j = 0; j < N; j++) {
+        Y_arr.col(i)[j] = res[j];
       }
     }
   }
@@ -98,39 +96,86 @@ void LayerNormFakeFp16Op::calcMeanStd(
   fp16_wrap(&inv_N_val);
   T tmp = T(0);
 
-  const int VEC_SIZE = 32;
+  constexpr int VEC_SIZE = 32;
+  std::vector<float> inv_N_vec(VEC_SIZE, inv_N_val);
+  std::vector<float> inv_N_prod_vec(VEC_SIZE, 0);
   std::vector<float> avgVec(VEC_SIZE, T(0));
   std::vector<float> sqrVec(VEC_SIZE, T(0));
   int numVecs = N / VEC_SIZE;
   int tailSize = N - (numVecs * VEC_SIZE);
+
+  vector<float> X_fp16(M * N);
+  fbgemm::RoundToFloat16(
+      X, X_fp16.data(), M * N, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
+
   for (int i = 0; i < M; ++i) {
     std::fill(avgVec.begin(), avgVec.end(), T(0));
     std::fill(sqrVec.begin(), sqrVec.end(), T(0));
     for (int j = 0; j < numVecs; ++j) {
-      for (int k = 0; k < VEC_SIZE; ++k) {
-        avgVec[k] = X_arr.col(i)[VEC_SIZE * j + k] * inv_N_val + avgVec[k];
-        fp16_wrap(&avgVec[k]);
-        tmp = X_arr.col(i)[VEC_SIZE * j + k] * inv_N_val;
-        fp16_wrap(&tmp);
-        sqrVec[k] = tmp * X_arr.col(i)[VEC_SIZE * j + k] + sqrVec[k];
-        fp16_wrap(&sqrVec[k]);
+      fake_fp16::fma_fp16(
+          VEC_SIZE,
+          &X_fp16[i * N + VEC_SIZE * j],
+          inv_N_vec.data(),
+          avgVec.data());
+      for (int k = 0; k < VEC_SIZE; k++) {
+        inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * j + k] * inv_N_val;
       }
+      fbgemm::RoundToFloat16(
+          inv_N_prod_vec.data(),
+          inv_N_prod_vec.data(),
+          VEC_SIZE,
+          FLAGS_caffe2_fbgemm_fake_fp16_clamp);
+
+      fake_fp16::fma_fp16(
+          VEC_SIZE,
+          &X_fp16[i * N + VEC_SIZE * j],
+          inv_N_prod_vec.data(),
+          sqrVec.data());
     }
-    for (int k = 0; k < tailSize; ++k) {
-      avgVec[k] = X_arr.col(i)[VEC_SIZE * numVecs + k] * inv_N_val + avgVec[k];
-      fp16_wrap(&avgVec[k]);
-      tmp = X_arr.col(i)[VEC_SIZE * numVecs + k] * inv_N_val;
-      fp16_wrap(&tmp);
-      sqrVec[k] = tmp * X_arr.col(i)[VEC_SIZE * numVecs + k] + sqrVec[k];
-      fp16_wrap(&sqrVec[k]);
+
+    if (tailSize > 0) {
+      fake_fp16::fma_fp16(
+          tailSize,
+          &X_fp16[i * N + VEC_SIZE * numVecs],
+          inv_N_vec.data(),
+          avgVec.data());
+      for (int k = 0; k < tailSize; k++) {
+        inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * numVecs + k] * inv_N_val;
+      }
+      fbgemm::RoundToFloat16(
+          inv_N_prod_vec.data(),
+          inv_N_prod_vec.data(),
+          tailSize,
+          FLAGS_caffe2_fbgemm_fake_fp16_clamp);
+
+      fake_fp16::fma_fp16(
+          tailSize,
+          &X_fp16[i * N + VEC_SIZE * numVecs],
+          inv_N_prod_vec.data(),
+          sqrVec.data());
     }
+
     mean[i] = ReducedAdd(avgVec);
     sqr[i] = ReducedAdd(sqrVec);
     // compute variance and std deviation
-    var[i] = -mean[i] * mean[i] + sqr[i];
-    fp16_wrap(&var[i]);
-    tmp = var[i] + eps;
+
+    float neg_mean = -mean[i];
+    fake_fp16::fma_fp16(1, &mean[i], &neg_mean, &sqr[i]);
+    var[i] = sqr[i];
+
+    if (var[i] < 0.0) {
+      LOG(WARNING) << "Variance " << var[i] << " negative, resetting to 0.";
+      var[i] = 0.0;
+    }
+
+    float teps = eps;
+    fp16_wrap(&teps);
+    tmp = var[i] + teps;
     fp16_wrap(&tmp);
+    if (tmp < 0) {
+      LOG(WARNING) << "Variance " << var[i] << " negative, resetting to 0.";
+      tmp = 0.0;
+    }
     std[i] = std::sqrt(tmp);
     fp16_wrap(&std[i]);
   }
diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
index 3c285fb05c4..b703d1f2ad5 100644
--- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
+++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h
@@ -51,14 +51,10 @@ class LayerNormFakeFp16Op final : public Operator<Context> {
     const int M = X.size_to_dim(canonical_axis);
     const int N = X.size_from_dim(canonical_axis);
     Y->ResizeLike(X);
-    scale_.Resize(M);
-    bias_.Resize(M);
     const T* X_data = X.template data<T>();
     T* Y_data = Y->template mutable_data<T>();
     T* mean_data = mean->template mutable_data<T>();
     T* sigma_data = sigma->template mutable_data<T>();
-    T* scale_data = scale_.template mutable_data<T>();
-    T* bias_data = bias_.template mutable_data<T>();
 
     std::vector<T> X_rounded(X.numel());
     fbgemm::RoundToFloat16(
@@ -138,9 +134,6 @@ class LayerNormFakeFp16Op final : public Operator<Context> {
   const float epsilon_;
   const bool elementwise_affine_;
 
-  Tensor scale_{Context::GetDeviceType()};
-  Tensor bias_{Context::GetDeviceType()};
-
   INPUT_TAGS(INPUT);
   OUTPUT_TAGS(OUTPUT, MEAN, STD);
 };
diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
index 920b3ffe1a6..129f9cd4bf5 100644
--- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
+++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py
@@ -25,19 +25,19 @@ GLOW_LOWERED_BATCHNORM = False
 
 # Test the lowered LayerNorm op
 class LayerNorm(serial.SerializedTestCase):
-    @given(seed=st.integers(0, 65535))
-    @settings(max_examples=10)
-    def test_layernorm(self, seed):
+    @given(seed=st.integers(0, 65535),
+           batch_size=st.integers(min_value=1, max_value=50),
+           size=st.integers(min_value=2, max_value=128),
+           epsilon=st.floats(min_value=1e-4, max_value=1e-3),
+           elementwise_affine=st.booleans())
+    @settings(max_examples=100)
+    def test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine):
         np.random.seed(seed)
         # Reset the workspace
-        size = 4
-        input_channels = 4
-        batch_size = 1
-        axis = 1
-        epsilon = 1e-4
         workspace.ResetWorkspace()
+        axis = 1
 
-        dims = np.array(([batch_size, input_channels, size, size]))
+        dims = np.array(([batch_size, size]))
         X = np.random.uniform(size=dims).astype(np.float32) - 0.5
         gamma = np.random.randn(*X.shape[axis:]).astype(np.float32)
         beta = np.random.randn(*X.shape[axis:]).astype(np.float32)
@@ -49,11 +49,11 @@ class LayerNorm(serial.SerializedTestCase):
         pred_net.op.add().CopyFrom(
             core.CreateOperator(
                 "LayerNorm",
-                ["X", "gamma", "beta"],
+                ["X", "gamma", "beta"] if elementwise_affine else ["X"],
                 ["Y", "mean", "rstd"],
-                axis=1,
+                axis=axis,
                 epsilon=epsilon,
-                elementwise_affine=True
+                elementwise_affine=elementwise_affine
             )
         )
 
@@ -64,11 +64,11 @@ class LayerNorm(serial.SerializedTestCase):
         pred_net_ref.op.add().CopyFrom(
             core.CreateOperator(
                 "LayerNormFakeFP16NNPI",
-                ["X", "gamma", "beta"],
+                ["X", "gamma", "beta"] if elementwise_affine else ["X"],
                 ["Y", "mean", "rstd"],
-                axis=1,
+                axis=axis,
                 epsilon=epsilon,
-                elementwise_affine=True
+                elementwise_affine=elementwise_affine
             )
         )
 
@@ -94,6 +94,10 @@ class LayerNorm(serial.SerializedTestCase):
         workspace.RunNet(pred_net_ref.name)
         Y_c2 = workspace.FetchBlob("Y")
 
+        dims1 = np.array(([1, *dims]))
+        X_glow = X.reshape(dims1)
+        workspace.FeedBlob("X", X_glow)
+
         workspace.RunNet(pred_net_onnxified.name)
         Y_glow = workspace.FetchBlob("Y")
 
@@ -104,10 +108,11 @@ class LayerNorm(serial.SerializedTestCase):
                 {
                     "seed": seed,
                     "size": size,
-                    "input_channels": input_channels,
                     "batch_size": batch_size,
                     "epsilon": epsilon,
-                    "axis": axis,
+                    "gamma": gamma,
+                    "beta": beta,
+                    "elementwise_affine": elementwise_affine,
                     "X": X,
                     "Y_glow": Y_glow,
                     "Y_c2": Y_c2,
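
Note (reviewer sketch, not part of the patch): the new calcMeanStd path computes mean = sum(x * 1/N) and E[x^2] = sum(x * (x * 1/N)) with emulated-fp16 FMAs, then derives var = -mean^2 + E[x^2] with one more FMA and clamps a slightly negative result before the sqrt. The standalone C++ below is a minimal illustration of that scheme, under the assumption T = float; round_fp16 is a stub standing in for fp16_wrap / fbgemm::RoundToFloat16, and this scalar fma_fp16 is a re-implementation for the sketch, not the fake_fp16::fma_fp16 from fp16_fma.h.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Stub: a faithful version would round v to the nearest fp16 value
    // (the patch uses fp16_wrap / fbgemm::RoundToFloat16 for this).
    static float round_fp16(float v) { return v; }

    // Emulated fp16 FMA over a strip: out[k] = round(a[k] * b[k] + out[k]).
    static void fma_fp16(int n, const float* a, const float* b, float* out) {
      for (int k = 0; k < n; ++k) {
        out[k] = round_fp16(std::fma(a[k], b[k], out[k]));
      }
    }

    int main() {
      const std::vector<float> x = {0.5f, -1.25f, 2.0f, 0.75f};
      const int n = static_cast<int>(x.size());
      const float inv_n = round_fp16(1.0f / n);
      const std::vector<float> inv_n_vec(n, inv_n);

      // avg accumulates x[k] * (1/N); sqr accumulates x[k] * (x[k] * (1/N)).
      std::vector<float> avg(n, 0.0f), prod(n, 0.0f), sqr(n, 0.0f);
      fma_fp16(n, x.data(), inv_n_vec.data(), avg.data());
      for (int k = 0; k < n; ++k) {
        prod[k] = round_fp16(x[k] * inv_n);
      }
      fma_fp16(n, x.data(), prod.data(), sqr.data());

      // Plain reductions here; the op reduces the strips with ReducedAdd.
      float mean = 0.0f, moment2 = 0.0f;
      for (int k = 0; k < n; ++k) {
        mean += avg[k];
        moment2 += sqr[k];
      }

      // var = -mean * mean + E[x^2] as a single FMA, clamped at zero to
      // guard against rounding driving the variance slightly negative.
      float neg_mean = -mean;
      fma_fp16(1, &mean, &neg_mean, &moment2);
      const float var = moment2 < 0.0f ? 0.0f : moment2;
      std::printf("mean=%f var=%f\n", mean, var);
      return 0;
    }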