fixes to layernorm emulation (#40422)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40422

Fix the remaining differences in the emulation of the fp16 LayerNorm operator.

Test Plan: unit test of layernorm

Reviewed By: venkatacrc

Differential Revision: D22182849

fbshipit-source-id: 8a45c21418517d65d7a41663d5ad2110d6b4677a
This commit is contained in:
Hector Yuen 2020-06-23 11:48:27 -07:00 committed by Facebook GitHub Bot
parent b82bd654cc
commit 27982d5711
3 changed files with 106 additions and 63 deletions

View file

@ -1,5 +1,6 @@
#include "layernorm_fp16_fake_op.h"
#include "caffe2/contrib/fakelowp/common.h"
#include "caffe2/contrib/fakelowp/fp16_fma.h"
namespace caffe2 {
@ -20,33 +21,30 @@ void LayerNormFakeFp16Op<CPUContext>::calcY(
EigenArrayMap<T> Y_arr(Y, N, M);
T tmp = T(0);
for (int i = 0; i < M; ++i) {
T normFactor = T(T(1) / std_arr[i]);
fp16_wrap(&normFactor);
for (int j = 0; j < N; ++j) {
T normalized = T(X_arr.col(i)[j] - mean[i]);
fp16_wrap(&normalized);
normalized *= normFactor;
fp16_wrap(&normalized);
Y_arr.col(i)[j] = normalized;
}
}
if (gamma != nullptr && beta != nullptr) {
ConstEigenVectorArrayMap<T> gamma_arr(gamma, N);
ConstEigenVectorArrayMap<T> beta_arr(beta, N);
for (int i = 0; i < M; ++i) {
T normFactor = T(T(1) / std_arr[i]);
fp16_wrap(&normFactor);
for (int j = 0; j < N; ++j) {
tmp = T(X_arr.col(i)[j] - mean[i]);
fp16_wrap(&tmp);
T normalized = tmp * normFactor;
fp16_wrap(&normalized);
tmp = normalized * gamma_arr[j] + beta_arr[j];
fp16_wrap(&tmp);
Y_arr.col(i)[j] = tmp;
vector<float> res(N);
for (int j = 0; j < N; j++) {
res[j] = beta[j];
}
}
} else {
for (int i = 0; i < M; ++i) {
T normFactor = T(T(1) / std_arr[i]);
fp16_wrap(&normFactor);
for (int j = 0; j < N; ++j) {
tmp = T(X_arr.col(i)[j] - mean[i]);
fp16_wrap(&tmp);
tmp *= normFactor;
fp16_wrap(&tmp);
Y_arr.col(i)[j] = tmp;
fake_fp16::fma_fp16(N, &Y_arr.col(i)[0], gamma, res.data());
for (int j = 0; j < N; j++) {
Y_arr.col(i)[j] = res[j];
}
}
}
@ -98,39 +96,86 @@ void LayerNormFakeFp16Op<CPUContext>::calcMeanStd(
fp16_wrap(&inv_N_val);
T tmp = T(0);
const int VEC_SIZE = 32;
constexpr int VEC_SIZE = 32;
std::vector<T> inv_N_vec(VEC_SIZE, inv_N_val);
std::vector<T> inv_N_prod_vec(VEC_SIZE, 0);
std::vector<T> avgVec(VEC_SIZE, T(0));
std::vector<T> sqrVec(VEC_SIZE, T(0));
int numVecs = N / VEC_SIZE;
int tailSize = N - (numVecs * VEC_SIZE);
vector<T> X_fp16(M * N);
fbgemm::RoundToFloat16(
X, X_fp16.data(), M * N, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
for (int i = 0; i < M; ++i) {
std::fill(avgVec.begin(), avgVec.end(), T(0));
std::fill(sqrVec.begin(), sqrVec.end(), T(0));
for (int j = 0; j < numVecs; ++j) {
for (int k = 0; k < VEC_SIZE; ++k) {
avgVec[k] = X_arr.col(i)[VEC_SIZE * j + k] * inv_N_val + avgVec[k];
fp16_wrap(&avgVec[k]);
tmp = X_arr.col(i)[VEC_SIZE * j + k] * inv_N_val;
fp16_wrap(&tmp);
sqrVec[k] = tmp * X_arr.col(i)[VEC_SIZE * j + k] + sqrVec[k];
fp16_wrap(&sqrVec[k]);
fake_fp16::fma_fp16(
VEC_SIZE,
&X_fp16[i * N + VEC_SIZE * j],
inv_N_vec.data(),
avgVec.data());
for (int k = 0; k < VEC_SIZE; k++) {
inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * j + k] * inv_N_val;
}
fbgemm::RoundToFloat16(
inv_N_prod_vec.data(),
inv_N_prod_vec.data(),
VEC_SIZE,
FLAGS_caffe2_fbgemm_fake_fp16_clamp);
fake_fp16::fma_fp16(
VEC_SIZE,
&X_fp16[i * N + VEC_SIZE * j],
inv_N_prod_vec.data(),
sqrVec.data());
}
for (int k = 0; k < tailSize; ++k) {
avgVec[k] = X_arr.col(i)[VEC_SIZE * numVecs + k] * inv_N_val + avgVec[k];
fp16_wrap(&avgVec[k]);
tmp = X_arr.col(i)[VEC_SIZE * numVecs + k] * inv_N_val;
fp16_wrap(&tmp);
sqrVec[k] = tmp * X_arr.col(i)[VEC_SIZE * numVecs + k] + sqrVec[k];
fp16_wrap(&sqrVec[k]);
if (tailSize > 0) {
fake_fp16::fma_fp16(
tailSize,
&X_fp16[i * N + VEC_SIZE * numVecs],
inv_N_vec.data(),
avgVec.data());
for (int k = 0; k < tailSize; k++) {
inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * numVecs + k] * inv_N_val;
}
fbgemm::RoundToFloat16(
inv_N_prod_vec.data(),
inv_N_prod_vec.data(),
tailSize,
FLAGS_caffe2_fbgemm_fake_fp16_clamp);
fake_fp16::fma_fp16(
tailSize,
&X_fp16[i * N + VEC_SIZE * numVecs],
inv_N_prod_vec.data(),
sqrVec.data());
}
mean[i] = ReducedAdd(avgVec);
sqr[i] = ReducedAdd(sqrVec);
// compute variance and std deviation
var[i] = -mean[i] * mean[i] + sqr[i];
fp16_wrap(&var[i]);
tmp = var[i] + eps;
float neg_mean = -mean[i];
fake_fp16::fma_fp16(1, &mean[i], &neg_mean, &sqr[i]);
var[i] = sqr[i];
if (var[i] < 0.0) {
LOG(WARNING) << "Variance " << var[i] << "negative, resetting to 0.";
var[i] = 0.0;
}
float teps = eps;
fp16_wrap(&teps);
tmp = var[i] + teps;
fp16_wrap(&tmp);
if (tmp < 0) {
LOG(WARNING) << "Variance " << var[i] << "negative, resetting to 0.";
tmp = 0.0;
}
std[i] = std::sqrt(tmp);
fp16_wrap(&std[i]);
}

View file

@ -51,14 +51,10 @@ class LayerNormFakeFp16Op final : public Operator<Context> {
const int M = X.size_to_dim(canonical_axis);
const int N = X.size_from_dim(canonical_axis);
Y->ResizeLike(X);
scale_.Resize(M);
bias_.Resize(M);
const T* X_data = X.template data<T>();
T* Y_data = Y->template mutable_data<T>();
T* mean_data = mean->template mutable_data<T>();
T* sigma_data = sigma->template mutable_data<T>();
T* scale_data = scale_.template mutable_data<T>();
T* bias_data = bias_.template mutable_data<T>();
std::vector<float> X_rounded(X.numel());
fbgemm::RoundToFloat16(
@ -138,9 +134,6 @@ class LayerNormFakeFp16Op final : public Operator<Context> {
const float epsilon_;
const bool elementwise_affine_;
Tensor scale_{Context::GetDeviceType()};
Tensor bias_{Context::GetDeviceType()};
INPUT_TAGS(INPUT);
OUTPUT_TAGS(OUTPUT, MEAN, STD);
};

View file

@ -25,19 +25,19 @@ GLOW_LOWERED_BATCHNORM = False
# Test the lowered LayerNorm op
class LayerNorm(serial.SerializedTestCase):
@given(seed=st.integers(0, 65535))
@settings(max_examples=10)
def test_layernorm(self, seed):
@given(seed=st.integers(0, 65535),
batch_size=st.integers(min_value=1, max_value=50),
size=st.integers(min_value=2, max_value=128),
epsilon=st.floats(min_value=1e-4, max_value=1e-3),
elementwise_affine=st.booleans())
@settings(max_examples=100)
def test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine):
np.random.seed(seed)
# Reset the workspace
size = 4
input_channels = 4
batch_size = 1
axis = 1
epsilon = 1e-4
workspace.ResetWorkspace()
axis = 1
dims = np.array(([batch_size, input_channels, size, size]))
dims = np.array(([batch_size, size]))
X = np.random.uniform(size=dims).astype(np.float32) - 0.5
gamma = np.random.randn(*X.shape[axis:]).astype(np.float32)
beta = np.random.randn(*X.shape[axis:]).astype(np.float32)
@ -49,11 +49,11 @@ class LayerNorm(serial.SerializedTestCase):
pred_net.op.add().CopyFrom(
core.CreateOperator(
"LayerNorm",
["X", "gamma", "beta"],
["X", "gamma", "beta"] if elementwise_affine else ["X"],
["Y", "mean", "rstd"],
axis=1,
axis=axis,
epsilon=epsilon,
elementwise_affine=True
elementwise_affine=elementwise_affine
)
)
@ -64,11 +64,11 @@ class LayerNorm(serial.SerializedTestCase):
pred_net_ref.op.add().CopyFrom(
core.CreateOperator(
"LayerNormFakeFP16NNPI",
["X", "gamma", "beta"],
["X", "gamma", "beta"] if elementwise_affine else ["X"],
["Y", "mean", "rstd"],
axis=1,
axis=axis,
epsilon=epsilon,
elementwise_affine=True
elementwise_affine=elementwise_affine
)
)
@ -94,6 +94,10 @@ class LayerNorm(serial.SerializedTestCase):
workspace.RunNet(pred_net_ref.name)
Y_c2 = workspace.FetchBlob("Y")
dims1 = np.array(([1, *dims]))
X_glow = X.reshape(dims1)
workspace.FeedBlob("X", X_glow)
workspace.RunNet(pred_net_onnxified.name)
Y_glow = workspace.FetchBlob("Y")
@ -104,10 +108,11 @@ class LayerNorm(serial.SerializedTestCase):
{
"seed": seed,
"size": size,
"input_channels": input_channels,
"batch_size": batch_size,
"epsilon": epsilon,
"axis": axis,
"gamma": gamma,
"beta": beta,
"elementwise_affine": elementwise_affine,
"X": X,
"Y_glow": Y_glow,
"Y_c2": Y_c2,