mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
fixes to layernorm emulation (#40422)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40422 fix the remaining differences to the emulation of fp16 layernorm Test Plan: unit test of layernorm Reviewed By: venkatacrc Differential Revision: D22182849 fbshipit-source-id: 8a45c21418517d65d7a41663d5ad2110d6b4677a
This commit is contained in:
parent
b82bd654cc
commit
27982d5711
3 changed files with 106 additions and 63 deletions
|
|
@ -1,5 +1,6 @@
|
|||
#include "layernorm_fp16_fake_op.h"
|
||||
#include "caffe2/contrib/fakelowp/common.h"
|
||||
#include "caffe2/contrib/fakelowp/fp16_fma.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
|
|
@ -20,33 +21,30 @@ void LayerNormFakeFp16Op<CPUContext>::calcY(
|
|||
EigenArrayMap<T> Y_arr(Y, N, M);
|
||||
T tmp = T(0);
|
||||
|
||||
for (int i = 0; i < M; ++i) {
|
||||
T normFactor = T(T(1) / std_arr[i]);
|
||||
fp16_wrap(&normFactor);
|
||||
for (int j = 0; j < N; ++j) {
|
||||
T normalized = T(X_arr.col(i)[j] - mean[i]);
|
||||
fp16_wrap(&normalized);
|
||||
normalized *= normFactor;
|
||||
fp16_wrap(&normalized);
|
||||
Y_arr.col(i)[j] = normalized;
|
||||
}
|
||||
}
|
||||
|
||||
if (gamma != nullptr && beta != nullptr) {
|
||||
ConstEigenVectorArrayMap<T> gamma_arr(gamma, N);
|
||||
ConstEigenVectorArrayMap<T> beta_arr(beta, N);
|
||||
|
||||
for (int i = 0; i < M; ++i) {
|
||||
T normFactor = T(T(1) / std_arr[i]);
|
||||
fp16_wrap(&normFactor);
|
||||
for (int j = 0; j < N; ++j) {
|
||||
tmp = T(X_arr.col(i)[j] - mean[i]);
|
||||
fp16_wrap(&tmp);
|
||||
T normalized = tmp * normFactor;
|
||||
fp16_wrap(&normalized);
|
||||
tmp = normalized * gamma_arr[j] + beta_arr[j];
|
||||
fp16_wrap(&tmp);
|
||||
Y_arr.col(i)[j] = tmp;
|
||||
vector<float> res(N);
|
||||
for (int j = 0; j < N; j++) {
|
||||
res[j] = beta[j];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < M; ++i) {
|
||||
T normFactor = T(T(1) / std_arr[i]);
|
||||
fp16_wrap(&normFactor);
|
||||
for (int j = 0; j < N; ++j) {
|
||||
tmp = T(X_arr.col(i)[j] - mean[i]);
|
||||
fp16_wrap(&tmp);
|
||||
tmp *= normFactor;
|
||||
fp16_wrap(&tmp);
|
||||
Y_arr.col(i)[j] = tmp;
|
||||
fake_fp16::fma_fp16(N, &Y_arr.col(i)[0], gamma, res.data());
|
||||
for (int j = 0; j < N; j++) {
|
||||
Y_arr.col(i)[j] = res[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -98,39 +96,86 @@ void LayerNormFakeFp16Op<CPUContext>::calcMeanStd(
|
|||
fp16_wrap(&inv_N_val);
|
||||
T tmp = T(0);
|
||||
|
||||
const int VEC_SIZE = 32;
|
||||
constexpr int VEC_SIZE = 32;
|
||||
std::vector<T> inv_N_vec(VEC_SIZE, inv_N_val);
|
||||
std::vector<T> inv_N_prod_vec(VEC_SIZE, 0);
|
||||
std::vector<T> avgVec(VEC_SIZE, T(0));
|
||||
std::vector<T> sqrVec(VEC_SIZE, T(0));
|
||||
int numVecs = N / VEC_SIZE;
|
||||
int tailSize = N - (numVecs * VEC_SIZE);
|
||||
|
||||
vector<T> X_fp16(M * N);
|
||||
fbgemm::RoundToFloat16(
|
||||
X, X_fp16.data(), M * N, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
|
||||
|
||||
for (int i = 0; i < M; ++i) {
|
||||
std::fill(avgVec.begin(), avgVec.end(), T(0));
|
||||
std::fill(sqrVec.begin(), sqrVec.end(), T(0));
|
||||
for (int j = 0; j < numVecs; ++j) {
|
||||
for (int k = 0; k < VEC_SIZE; ++k) {
|
||||
avgVec[k] = X_arr.col(i)[VEC_SIZE * j + k] * inv_N_val + avgVec[k];
|
||||
fp16_wrap(&avgVec[k]);
|
||||
tmp = X_arr.col(i)[VEC_SIZE * j + k] * inv_N_val;
|
||||
fp16_wrap(&tmp);
|
||||
sqrVec[k] = tmp * X_arr.col(i)[VEC_SIZE * j + k] + sqrVec[k];
|
||||
fp16_wrap(&sqrVec[k]);
|
||||
fake_fp16::fma_fp16(
|
||||
VEC_SIZE,
|
||||
&X_fp16[i * N + VEC_SIZE * j],
|
||||
inv_N_vec.data(),
|
||||
avgVec.data());
|
||||
for (int k = 0; k < VEC_SIZE; k++) {
|
||||
inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * j + k] * inv_N_val;
|
||||
}
|
||||
fbgemm::RoundToFloat16(
|
||||
inv_N_prod_vec.data(),
|
||||
inv_N_prod_vec.data(),
|
||||
VEC_SIZE,
|
||||
FLAGS_caffe2_fbgemm_fake_fp16_clamp);
|
||||
|
||||
fake_fp16::fma_fp16(
|
||||
VEC_SIZE,
|
||||
&X_fp16[i * N + VEC_SIZE * j],
|
||||
inv_N_prod_vec.data(),
|
||||
sqrVec.data());
|
||||
}
|
||||
for (int k = 0; k < tailSize; ++k) {
|
||||
avgVec[k] = X_arr.col(i)[VEC_SIZE * numVecs + k] * inv_N_val + avgVec[k];
|
||||
fp16_wrap(&avgVec[k]);
|
||||
tmp = X_arr.col(i)[VEC_SIZE * numVecs + k] * inv_N_val;
|
||||
fp16_wrap(&tmp);
|
||||
sqrVec[k] = tmp * X_arr.col(i)[VEC_SIZE * numVecs + k] + sqrVec[k];
|
||||
fp16_wrap(&sqrVec[k]);
|
||||
|
||||
if (tailSize > 0) {
|
||||
fake_fp16::fma_fp16(
|
||||
tailSize,
|
||||
&X_fp16[i * N + VEC_SIZE * numVecs],
|
||||
inv_N_vec.data(),
|
||||
avgVec.data());
|
||||
for (int k = 0; k < tailSize; k++) {
|
||||
inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * numVecs + k] * inv_N_val;
|
||||
}
|
||||
fbgemm::RoundToFloat16(
|
||||
inv_N_prod_vec.data(),
|
||||
inv_N_prod_vec.data(),
|
||||
tailSize,
|
||||
FLAGS_caffe2_fbgemm_fake_fp16_clamp);
|
||||
|
||||
fake_fp16::fma_fp16(
|
||||
tailSize,
|
||||
&X_fp16[i * N + VEC_SIZE * numVecs],
|
||||
inv_N_prod_vec.data(),
|
||||
sqrVec.data());
|
||||
}
|
||||
|
||||
mean[i] = ReducedAdd(avgVec);
|
||||
sqr[i] = ReducedAdd(sqrVec);
|
||||
// compute variance and std deviation
|
||||
var[i] = -mean[i] * mean[i] + sqr[i];
|
||||
fp16_wrap(&var[i]);
|
||||
tmp = var[i] + eps;
|
||||
|
||||
float neg_mean = -mean[i];
|
||||
fake_fp16::fma_fp16(1, &mean[i], &neg_mean, &sqr[i]);
|
||||
var[i] = sqr[i];
|
||||
|
||||
if (var[i] < 0.0) {
|
||||
LOG(WARNING) << "Variance " << var[i] << "negative, resetting to 0.";
|
||||
var[i] = 0.0;
|
||||
}
|
||||
|
||||
float teps = eps;
|
||||
fp16_wrap(&teps);
|
||||
tmp = var[i] + teps;
|
||||
fp16_wrap(&tmp);
|
||||
if (tmp < 0) {
|
||||
LOG(WARNING) << "Variance " << var[i] << "negative, resetting to 0.";
|
||||
tmp = 0.0;
|
||||
}
|
||||
std[i] = std::sqrt(tmp);
|
||||
fp16_wrap(&std[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,14 +51,10 @@ class LayerNormFakeFp16Op final : public Operator<Context> {
|
|||
const int M = X.size_to_dim(canonical_axis);
|
||||
const int N = X.size_from_dim(canonical_axis);
|
||||
Y->ResizeLike(X);
|
||||
scale_.Resize(M);
|
||||
bias_.Resize(M);
|
||||
const T* X_data = X.template data<T>();
|
||||
T* Y_data = Y->template mutable_data<T>();
|
||||
T* mean_data = mean->template mutable_data<T>();
|
||||
T* sigma_data = sigma->template mutable_data<T>();
|
||||
T* scale_data = scale_.template mutable_data<T>();
|
||||
T* bias_data = bias_.template mutable_data<T>();
|
||||
|
||||
std::vector<float> X_rounded(X.numel());
|
||||
fbgemm::RoundToFloat16(
|
||||
|
|
@ -138,9 +134,6 @@ class LayerNormFakeFp16Op final : public Operator<Context> {
|
|||
const float epsilon_;
|
||||
const bool elementwise_affine_;
|
||||
|
||||
Tensor scale_{Context::GetDeviceType()};
|
||||
Tensor bias_{Context::GetDeviceType()};
|
||||
|
||||
INPUT_TAGS(INPUT);
|
||||
OUTPUT_TAGS(OUTPUT, MEAN, STD);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -25,19 +25,19 @@ GLOW_LOWERED_BATCHNORM = False
|
|||
# Test the lowered LayerNorm op
|
||||
class LayerNorm(serial.SerializedTestCase):
|
||||
|
||||
@given(seed=st.integers(0, 65535))
|
||||
@settings(max_examples=10)
|
||||
def test_layernorm(self, seed):
|
||||
@given(seed=st.integers(0, 65535),
|
||||
batch_size=st.integers(min_value=1, max_value=50),
|
||||
size=st.integers(min_value=2, max_value=128),
|
||||
epsilon=st.floats(min_value=1e-4, max_value=1e-3),
|
||||
elementwise_affine=st.booleans())
|
||||
@settings(max_examples=100)
|
||||
def test_layernorm(self, seed, batch_size, size, epsilon, elementwise_affine):
|
||||
np.random.seed(seed)
|
||||
# Reset the workspace
|
||||
size = 4
|
||||
input_channels = 4
|
||||
batch_size = 1
|
||||
axis = 1
|
||||
epsilon = 1e-4
|
||||
workspace.ResetWorkspace()
|
||||
axis = 1
|
||||
|
||||
dims = np.array(([batch_size, input_channels, size, size]))
|
||||
dims = np.array(([batch_size, size]))
|
||||
X = np.random.uniform(size=dims).astype(np.float32) - 0.5
|
||||
gamma = np.random.randn(*X.shape[axis:]).astype(np.float32)
|
||||
beta = np.random.randn(*X.shape[axis:]).astype(np.float32)
|
||||
|
|
@ -49,11 +49,11 @@ class LayerNorm(serial.SerializedTestCase):
|
|||
pred_net.op.add().CopyFrom(
|
||||
core.CreateOperator(
|
||||
"LayerNorm",
|
||||
["X", "gamma", "beta"],
|
||||
["X", "gamma", "beta"] if elementwise_affine else ["X"],
|
||||
["Y", "mean", "rstd"],
|
||||
axis=1,
|
||||
axis=axis,
|
||||
epsilon=epsilon,
|
||||
elementwise_affine=True
|
||||
elementwise_affine=elementwise_affine
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -64,11 +64,11 @@ class LayerNorm(serial.SerializedTestCase):
|
|||
pred_net_ref.op.add().CopyFrom(
|
||||
core.CreateOperator(
|
||||
"LayerNormFakeFP16NNPI",
|
||||
["X", "gamma", "beta"],
|
||||
["X", "gamma", "beta"] if elementwise_affine else ["X"],
|
||||
["Y", "mean", "rstd"],
|
||||
axis=1,
|
||||
axis=axis,
|
||||
epsilon=epsilon,
|
||||
elementwise_affine=True
|
||||
elementwise_affine=elementwise_affine
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -94,6 +94,10 @@ class LayerNorm(serial.SerializedTestCase):
|
|||
workspace.RunNet(pred_net_ref.name)
|
||||
Y_c2 = workspace.FetchBlob("Y")
|
||||
|
||||
dims1 = np.array(([1, *dims]))
|
||||
X_glow = X.reshape(dims1)
|
||||
workspace.FeedBlob("X", X_glow)
|
||||
|
||||
workspace.RunNet(pred_net_onnxified.name)
|
||||
Y_glow = workspace.FetchBlob("Y")
|
||||
|
||||
|
|
@ -104,10 +108,11 @@ class LayerNorm(serial.SerializedTestCase):
|
|||
{
|
||||
"seed": seed,
|
||||
"size": size,
|
||||
"input_channels": input_channels,
|
||||
"batch_size": batch_size,
|
||||
"epsilon": epsilon,
|
||||
"axis": axis,
|
||||
"gamma": gamma,
|
||||
"beta": beta,
|
||||
"elementwise_affine": elementwise_affine,
|
||||
"X": X,
|
||||
"Y_glow": Y_glow,
|
||||
"Y_c2": Y_c2,
|
||||
|
|
|
|||
Loading…
Reference in a new issue