mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
LayerNormGrad function body and LayerNorm inference/body fix (#9160)
* Add function body for LayerNormGrad * Fix LayerNorm schema for multiple normalization dims
This commit is contained in:
parent
e1b84eefcc
commit
e79be39081
5 changed files with 155 additions and 10 deletions
|
|
@ -2441,13 +2441,15 @@ Example 4:
|
|||
if (ctx.getNumOutputs() > 1) {
|
||||
auto saved_mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape();
|
||||
saved_mean_shape->CopyFrom(input_shape);
|
||||
saved_mean_shape->mutable_dim(static_cast<int>(axis))->set_dim_value(1);
|
||||
for (int d = static_cast<int>(axis); d < input_ndim; ++d)
|
||||
saved_mean_shape->mutable_dim(d)->set_dim_value(1);
|
||||
}
|
||||
|
||||
if (ctx.getNumOutputs() > 2) {
|
||||
auto saved_inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape();
|
||||
saved_inv_std_dev_shape->CopyFrom(input_shape);
|
||||
saved_inv_std_dev_shape->mutable_dim(static_cast<int>(axis))->set_dim_value(1);
|
||||
for (int d = static_cast<int>(axis); d < input_ndim; ++d)
|
||||
saved_inv_std_dev_shape->mutable_dim(d)->set_dim_value(1);
|
||||
}
|
||||
})
|
||||
.SetContextDependentFunctionBodyBuilder(
|
||||
|
|
@ -2507,9 +2509,11 @@ Example 4:
|
|||
{{"Deviation"}, "Sub", {"XU", "Mean2D"}},
|
||||
{{"Normalized"}, "Div", {"Deviation", "StdDev"}},
|
||||
{{"NormalizedT"}, "Cast", {"Normalized"}, {{"to", T}}},
|
||||
{{"Scaled"}, "Mul", {"NormalizedT", "Scale"}}};
|
||||
{{"Scale2D"}, "Flatten", {"Scale"}, {{"axis", int64_t(0)}}},
|
||||
{{"Scaled"}, "Mul", {"NormalizedT", "Scale2D"}}};
|
||||
if (ctx.hasInput(2)) {
|
||||
body.push_back({{"Biased"}, "Add", {"Scaled", "B"}});
|
||||
body.push_back({{"B2D"}, "Flatten", {"B"}, {{"axis", int64_t(0)}}});
|
||||
body.push_back({{"Biased"}, "Add", {"Scaled", "B2D"}});
|
||||
} else {
|
||||
body.push_back({{"Biased"}, "Identity", {"Scaled"}});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include "onnx/onnx-operators_pb.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "onnx/defs/function.h"
|
||||
#include "onnx/defs/parser.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
|
||||
|
|
@ -20,4 +21,55 @@ inline static FunctionBodyHelper::NodeDef Const(const std::string& name, double
|
|||
return FunctionBodyHelper::NodeDef{
|
||||
{name}, "Constant", {}, {{"value", ToTensor(value, elem_type)}}};
|
||||
}
|
||||
|
||||
class FunctionBuilder {
|
||||
public:
|
||||
FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
|
||||
|
||||
FunctionBuilder& Add(const char* nodes_txt) {
|
||||
OnnxParser parser(nodes_txt);
|
||||
auto& nodes = *funProto.mutable_node();
|
||||
|
||||
while (!parser.EndOfInput()) {
|
||||
auto status = parser.Parse(*nodes.Add());
|
||||
if (!status.IsOK())
|
||||
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
|
||||
OnnxParser parser(node_txt);
|
||||
auto& node = *funProto.add_node();
|
||||
auto status = parser.Parse(node);
|
||||
if (!status.IsOK()) {
|
||||
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
|
||||
}
|
||||
|
||||
if (!parser.EndOfInput()) {
|
||||
ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
|
||||
}
|
||||
|
||||
*node.add_attribute() = attr;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) {
|
||||
return Add (node_txt, MakeAttribute(attr_name, attr_value));
|
||||
}
|
||||
|
||||
FunctionBuilder& AddOpset(const char* domain, int version) {
|
||||
auto* opset = funProto.add_opset_import();
|
||||
opset->set_domain(domain);
|
||||
opset->set_version(version);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
FunctionProto& funProto;
|
||||
};
|
||||
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
|
@ -27,10 +27,8 @@ class ContribFunExpansionTest : public ::testing::Test {
|
|||
};
|
||||
|
||||
template <typename T, typename U, bool RunTest>
|
||||
void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) {
|
||||
void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true, std::vector<int64_t> shape1 = {8, 16}, std::vector<int64_t> shape2 = {16}, int64_t axis = -1) {
|
||||
FunctionTestCase testCase("LayerNormalization", kOnnxDomain);
|
||||
std::vector<int64_t> shape1{8, 16};
|
||||
std::vector<int64_t> shape2{16};
|
||||
|
||||
testCase.AddInput<T, RunTest>("x", shape1);
|
||||
testCase.AddInput<T, RunTest>("scale", shape2);
|
||||
|
|
@ -39,6 +37,8 @@ void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) {
|
|||
testCase.AddOutput(compute_mean ? "mean" : "");
|
||||
testCase.AddOutput(compute_isd ? "invstddev" : "");
|
||||
testCase.AddAttribute("stash_type", data_types_internal::ToTensorDataType<U>());
|
||||
if (axis != -1)
|
||||
testCase.AddAttribute("axis", axis);
|
||||
if (RunTest)
|
||||
testCase.RunTest();
|
||||
else
|
||||
|
|
@ -59,6 +59,11 @@ TEST_F(ContribFunExpansionTest, LayerNorm_OptionalOutputs) {
|
|||
CheckLayerNorm<float, float, true>(true, false);
|
||||
}
|
||||
|
||||
TEST_F(ContribFunExpansionTest, LayerNorm_OtherShapes) {
  // Expand-and-run with a rank-3 input and an explicit (non-default) axis.
  const std::vector<int64_t> input_shape{4, 2, 8};
  const std::vector<int64_t> scale_shape{2, 8};
  const int64_t axis = 1;
  CheckLayerNorm<float, float, true>(true, true, input_shape, scale_shape, axis);
}
|
||||
|
||||
template <typename T>
|
||||
void CheckGelu() {
|
||||
FunctionTestCase testCase("Gelu", kMSDomain);
|
||||
|
|
|
|||
|
|
@ -2095,7 +2095,7 @@ Example 4:
|
|||
.Input(1, "X", "Input data tensor from the forward path", "T")
|
||||
.Input(2, "scale", "Scale tensor.", "T")
|
||||
.Input(3, "mean", "mean of X.", "U")
|
||||
.Input(4, "inv_std_var", "inverse std variance of X.", "U")
|
||||
.Input(4, "inv_std_dev", "inverse std deviation of X.", "U")
|
||||
.Output(0, "X_grad", "Gradient of the input.", "T")
|
||||
.Output(1, "scale_grad", "Gradient of the scale.", "T")
|
||||
.Output(2, "bias_grad", "Gradient of the bias.", "T")
|
||||
|
|
@ -2115,7 +2115,59 @@ Example 4:
|
|||
// The bias tensor has the same shape of the scale tensor.
|
||||
propagateElemTypeFromInputToOutput(ctx, 2, 2);
|
||||
propagateShapeFromInputToOutput(ctx, 2, 2);
|
||||
});
|
||||
})
|
||||
.SetContextDependentFunctionBodyBuilder(
|
||||
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
FunctionBuilder builder(functionProto);
|
||||
|
||||
auto* tp = ctx.getInputType(0);
|
||||
if ((tp == nullptr) || (!tp->has_tensor_type()))
|
||||
return false;
|
||||
int64_t T = tp->tensor_type().elem_type();
|
||||
|
||||
// Requirements/assumptions:
|
||||
// Inputs Y_grad and X are of shape [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]] and type T
|
||||
// Input scale is of shape [d[axis], ..., d[rank-1]] and type T
|
||||
// Inputs mean and inv_std_dev are of shape [d[0], ..., d[axis-1], 1, ..., 1] (same rank as X)
|
||||
// and type U.
|
||||
//
|
||||
auto axis_ref_attr = MakeRefAttribute("axis", AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
|
||||
builder
|
||||
.AddOpset("", 15)
|
||||
.Add("cast_mean = Cast (mean)", "to", T)
|
||||
.Add("cast_inv_std_dev = Cast(inv_std_dev)", "to", T)
|
||||
.Add("x_2d = Flatten (X)", axis_ref_attr)
|
||||
.Add("Y_grad_2d = Flatten (Y_grad)", axis_ref_attr)
|
||||
.Add("mean_2d = Flatten (cast_mean)", axis_ref_attr)
|
||||
.Add("inv_std_dev_2d = Flatten (cast_inv_std_dev)", axis_ref_attr)
|
||||
.Add(R"ONNX(
|
||||
shape_x = Shape (X)
|
||||
bias_scale_shape = Shape (scale)
|
||||
scale_2d = Flatten <axis = 0> (scale)
|
||||
|
||||
axis_0 = Constant <value = int64[1] {0}> ()
|
||||
bias_grad_2d = ReduceSum (Y_grad_2d, axis_0)
|
||||
bias_grad = Reshape (bias_grad_2d, bias_scale_shape)
|
||||
|
||||
deviation = Sub (x_2d, mean_2d)
|
||||
normalized_deviation = Mul(deviation, inv_std_dev_2d)
|
||||
scale_grad_rows = Mul (Y_grad_2d, normalized_deviation)
|
||||
scale_grad_2d = ReduceSum (scale_grad_rows, axis_0)
|
||||
scale_grad = Reshape (scale_grad_2d, bias_scale_shape)
|
||||
normalized_layer_grad = Mul (Y_grad_2d, scale_2d)
|
||||
|
||||
B = Mul (normalized_layer_grad, inv_std_dev_2d)
|
||||
C = Mul (B, normalized_deviation)
|
||||
mean_B = ReduceMean <axes = [1]> (B)
|
||||
mean_C = ReduceMean <axes = [1]> (C)
|
||||
nd_mean_C = Mul (normalized_deviation, mean_C)
|
||||
mean_diff_B = Sub (B, mean_B)
|
||||
X_grad_2D = Sub (mean_diff_B, nd_mean_C)
|
||||
X_grad = Reshape (X_grad_2D, shape_x)
|
||||
)ONNX");
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SimplifiedLayerNormalizationGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
|
|||
|
|
@ -171,7 +171,6 @@ TEST_F(FunExpansionTest, DropoutGrad_WithRatio2) {
|
|||
|
||||
template <typename T, bool RunTest = true>
|
||||
void TestUnaryOpGrad(const char* opname) {
|
||||
|
||||
FunctionTestCase testCase(opname);
|
||||
std::vector<int64_t> shape{16, 4};
|
||||
testCase.AddInput<T, RunTest>("dY", shape);
|
||||
|
|
@ -199,5 +198,38 @@ TEST_F(FunExpansionTest, FastGeluGrad) {
|
|||
TestUnaryOpGrad<MLFloat16, false>("FastGeluGrad");
|
||||
}
|
||||
|
||||
template <typename T, typename U, bool RunTest = true>
|
||||
void TestLayerNormGrad(std::vector<int64_t> prefix_shape, std::vector<int64_t> suffix_shape) {
|
||||
FunctionTestCase testCase("LayerNormalizationGrad");
|
||||
std::vector<int64_t> input_shape(prefix_shape);
|
||||
for (auto d : suffix_shape)
|
||||
input_shape.push_back(d);
|
||||
std::vector<int64_t> stats_shape(prefix_shape);
|
||||
for (auto d : suffix_shape) {
|
||||
(void)d;
|
||||
stats_shape.push_back(1);
|
||||
}
|
||||
testCase.AddInput<T, RunTest>("Y_grad", input_shape);
|
||||
testCase.AddInput<T, RunTest>("X", input_shape);
|
||||
testCase.AddInput<T, RunTest>("scale", suffix_shape);
|
||||
testCase.AddInput<U, RunTest>("mean", stats_shape);
|
||||
testCase.AddInput<U, RunTest>("inv_std_dev", stats_shape);
|
||||
testCase.AddOutput("X_grad");
|
||||
testCase.AddOutput("scale_grad");
|
||||
testCase.AddOutput("bias_grad");
|
||||
testCase.AddAttribute("axis", prefix_shape.size());
|
||||
if (RunTest)
|
||||
testCase.RunTest();
|
||||
else
|
||||
// Test only expanded model creation and model checking.
|
||||
testCase.CreateModel(true);
|
||||
}
|
||||
|
||||
TEST_F(FunExpansionTest, LayerNormalizationGrad) {
  // Every case uses the same trailing dimensions.
  const std::vector<int64_t> suffix_dims{8, 4};

  // float/float with a non-empty and then an empty prefix (RunTest = true).
  TestLayerNormGrad<float, float, true>({4, 1}, suffix_dims);
  TestLayerNormGrad<float, float, true>({}, suffix_dims);

  // BFloat16 inputs with float statistics (RunTest = false).
  TestLayerNormGrad<BFloat16, float, false>({}, suffix_dims);
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue