LayerNormGrad function body and LayerNorm inference/body fix (#9160)

* Add function body for LayerNormGrad

* Fix LayerNorm schema for multiple normalization dims
This commit is contained in:
G. Ramalingam 2021-09-30 12:03:08 -07:00 committed by GitHub
parent e1b84eefcc
commit e79be39081
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 155 additions and 10 deletions

View file

@ -2441,13 +2441,15 @@ Example 4:
if (ctx.getNumOutputs() > 1) {
auto saved_mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape();
saved_mean_shape->CopyFrom(input_shape);
saved_mean_shape->mutable_dim(static_cast<int>(axis))->set_dim_value(1);
for (int d = static_cast<int>(axis); d < input_ndim; ++d)
saved_mean_shape->mutable_dim(d)->set_dim_value(1);
}
if (ctx.getNumOutputs() > 2) {
auto saved_inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape();
saved_inv_std_dev_shape->CopyFrom(input_shape);
saved_inv_std_dev_shape->mutable_dim(static_cast<int>(axis))->set_dim_value(1);
for (int d = static_cast<int>(axis); d < input_ndim; ++d)
saved_inv_std_dev_shape->mutable_dim(d)->set_dim_value(1);
}
})
.SetContextDependentFunctionBodyBuilder(
@ -2507,9 +2509,11 @@ Example 4:
{{"Deviation"}, "Sub", {"XU", "Mean2D"}},
{{"Normalized"}, "Div", {"Deviation", "StdDev"}},
{{"NormalizedT"}, "Cast", {"Normalized"}, {{"to", T}}},
{{"Scaled"}, "Mul", {"NormalizedT", "Scale"}}};
{{"Scale2D"}, "Flatten", {"Scale"}, {{"axis", int64_t(0)}}},
{{"Scaled"}, "Mul", {"NormalizedT", "Scale2D"}}};
if (ctx.hasInput(2)) {
body.push_back({{"Biased"}, "Add", {"Scaled", "B"}});
body.push_back({{"B2D"}, "Flatten", {"B"}, {{"axis", int64_t(0)}}});
body.push_back({{"Biased"}, "Add", {"Scaled", "B2D"}});
} else {
body.push_back({{"Biased"}, "Identity", {"Scaled"}});
}

View file

@ -9,6 +9,7 @@
#include "onnx/onnx-operators_pb.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/function.h"
#include "onnx/defs/parser.h"
namespace ONNX_NAMESPACE {
@ -20,4 +21,55 @@ inline static FunctionBodyHelper::NodeDef Const(const std::string& name, double
return FunctionBodyHelper::NodeDef{
{name}, "Constant", {}, {{"value", ToTensor(value, elem_type)}}};
}
class FunctionBuilder {
public:
FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
FunctionBuilder& Add(const char* nodes_txt) {
OnnxParser parser(nodes_txt);
auto& nodes = *funProto.mutable_node();
while (!parser.EndOfInput()) {
auto status = parser.Parse(*nodes.Add());
if (!status.IsOK())
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
return *this;
}
FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
OnnxParser parser(node_txt);
auto& node = *funProto.add_node();
auto status = parser.Parse(node);
if (!status.IsOK()) {
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
if (!parser.EndOfInput()) {
ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
}
*node.add_attribute() = attr;
return *this;
}
template <typename T>
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) {
return Add (node_txt, MakeAttribute(attr_name, attr_value));
}
FunctionBuilder& AddOpset(const char* domain, int version) {
auto* opset = funProto.add_opset_import();
opset->set_domain(domain);
opset->set_version(version);
return *this;
}
private:
FunctionProto& funProto;
};
} // namespace ONNX_NAMESPACE

View file

@ -27,10 +27,8 @@ class ContribFunExpansionTest : public ::testing::Test {
};
template <typename T, typename U, bool RunTest>
void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) {
void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true, std::vector<int64_t> shape1 = {8, 16}, std::vector<int64_t> shape2 = {16}, int64_t axis = -1) {
FunctionTestCase testCase("LayerNormalization", kOnnxDomain);
std::vector<int64_t> shape1{8, 16};
std::vector<int64_t> shape2{16};
testCase.AddInput<T, RunTest>("x", shape1);
testCase.AddInput<T, RunTest>("scale", shape2);
@ -39,6 +37,8 @@ void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) {
testCase.AddOutput(compute_mean ? "mean" : "");
testCase.AddOutput(compute_isd ? "invstddev" : "");
testCase.AddAttribute("stash_type", data_types_internal::ToTensorDataType<U>());
if (axis != -1)
testCase.AddAttribute("axis", axis);
if (RunTest)
testCase.RunTest();
else
@ -59,6 +59,11 @@ TEST_F(ContribFunExpansionTest, LayerNorm_OptionalOutputs) {
CheckLayerNorm<float, float, true>(true, false);
}
TEST_F(ContribFunExpansionTest, LayerNorm_OtherShapes) {
// Test expand-and-run
CheckLayerNorm<float, float, true>(true, true, {4, 2, 8}, {2, 8}, 1);
}
template <typename T>
void CheckGelu() {
FunctionTestCase testCase("Gelu", kMSDomain);

View file

@ -2095,7 +2095,7 @@ Example 4:
.Input(1, "X", "Input data tensor from the forward path", "T")
.Input(2, "scale", "Scale tensor.", "T")
.Input(3, "mean", "mean of X.", "U")
.Input(4, "inv_std_var", "inverse std variance of X.", "U")
.Input(4, "inv_std_dev", "inverse std deviation of X.", "U")
.Output(0, "X_grad", "Gradient of the input.", "T")
.Output(1, "scale_grad", "Gradient of the scale.", "T")
.Output(2, "bias_grad", "Gradient of the bias.", "T")
@ -2115,7 +2115,59 @@ Example 4:
// The bias tensor has the same shape of the scale tensor.
propagateElemTypeFromInputToOutput(ctx, 2, 2);
propagateShapeFromInputToOutput(ctx, 2, 2);
});
})
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
FunctionBuilder builder(functionProto);
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
int64_t T = tp->tensor_type().elem_type();
// Requirements/assumptions:
// Inputs Y_grad and X are of shape [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]] and type T
// Input scale is of shape [d[axis], ..., d[rank-1]] and type U
// Inputs mean and inv_std_dev are of shape [d[0], ..., d[axis-1], 1, ..., 1] (same rank as X)
// and type U.
//
auto axis_ref_attr = MakeRefAttribute("axis", AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
builder
.AddOpset("", 15)
.Add("cast_mean = Cast (mean)", "to", T)
.Add("cast_inv_std_dev = Cast(inv_std_dev)", "to", T)
.Add("x_2d = Flatten (X)", axis_ref_attr)
.Add("Y_grad_2d = Flatten (Y_grad)", axis_ref_attr)
.Add("mean_2d = Flatten (cast_mean)", axis_ref_attr)
.Add("inv_std_dev_2d = Flatten (cast_inv_std_dev)", axis_ref_attr)
.Add(R"ONNX(
shape_x = Shape (X)
bias_scale_shape = Shape (scale)
scale_2d = Flatten <axis = 0> (scale)
axis_0 = Constant <value = int64[1] {0}> ()
bias_grad_2d = ReduceSum (Y_grad_2d, axis_0)
bias_grad = Reshape (bias_grad_2d, bias_scale_shape)
deviation = Sub (x_2d, mean_2d)
normalized_deviation = Mul(deviation, inv_std_dev_2d)
scale_grad_rows = Mul (Y_grad_2d, normalized_deviation)
scale_grad_2d = ReduceSum (scale_grad_rows, axis_0)
scale_grad = Reshape (scale_grad_2d, bias_scale_shape)
normalized_layer_grad = Mul (Y_grad_2d, scale_2d)
B = Mul (normalized_layer_grad, inv_std_dev_2d)
C = Mul (B, normalized_deviation)
mean_B = ReduceMean <axes = [1]> (B)
mean_C = ReduceMean <axes = [1]> (C)
nd_mean_C = Mul (normalized_deviation, mean_C)
mean_diff_B = Sub (B, mean_B)
X_grad_2D = Sub (mean_diff_B, nd_mean_C)
X_grad = Reshape (X_grad_2D, shape_x)
)ONNX");
schema.BuildFunction(functionProto);
return true;
});
ONNX_CONTRIB_OPERATOR_SCHEMA(SimplifiedLayerNormalizationGrad)
.SetDomain(kMSDomain)

View file

@ -171,7 +171,6 @@ TEST_F(FunExpansionTest, DropoutGrad_WithRatio2) {
template <typename T, bool RunTest = true>
void TestUnaryOpGrad(const char* opname) {
FunctionTestCase testCase(opname);
std::vector<int64_t> shape{16, 4};
testCase.AddInput<T, RunTest>("dY", shape);
@ -199,5 +198,38 @@ TEST_F(FunExpansionTest, FastGeluGrad) {
TestUnaryOpGrad<MLFloat16, false>("FastGeluGrad");
}
template <typename T, typename U, bool RunTest = true>
void TestLayerNormGrad(std::vector<int64_t> prefix_shape, std::vector<int64_t> suffix_shape) {
FunctionTestCase testCase("LayerNormalizationGrad");
std::vector<int64_t> input_shape(prefix_shape);
for (auto d : suffix_shape)
input_shape.push_back(d);
std::vector<int64_t> stats_shape(prefix_shape);
for (auto d : suffix_shape) {
(void)d;
stats_shape.push_back(1);
}
testCase.AddInput<T, RunTest>("Y_grad", input_shape);
testCase.AddInput<T, RunTest>("X", input_shape);
testCase.AddInput<T, RunTest>("scale", suffix_shape);
testCase.AddInput<U, RunTest>("mean", stats_shape);
testCase.AddInput<U, RunTest>("inv_std_dev", stats_shape);
testCase.AddOutput("X_grad");
testCase.AddOutput("scale_grad");
testCase.AddOutput("bias_grad");
testCase.AddAttribute("axis", prefix_shape.size());
if (RunTest)
testCase.RunTest();
else
// Test only expanded model creation and model checking.
testCase.CreateModel(true);
}
TEST_F(FunExpansionTest, LayerNormalizationGrad) {
TestLayerNormGrad<float, float, true>({4, 1}, {8, 4});
TestLayerNormGrad<float, float, true>({}, {8, 4});
TestLayerNormGrad<BFloat16, float, false>({}, {8, 4});
}
} // namespace test
} // namespace onnxruntime