diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4e5eb5ea20..429c9b18ad 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2441,13 +2441,15 @@ Example 4: if (ctx.getNumOutputs() > 1) { auto saved_mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); saved_mean_shape->CopyFrom(input_shape); - saved_mean_shape->mutable_dim(static_cast(axis))->set_dim_value(1); + for (int d = static_cast(axis); d < input_ndim; ++d) + saved_mean_shape->mutable_dim(d)->set_dim_value(1); } if (ctx.getNumOutputs() > 2) { auto saved_inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); saved_inv_std_dev_shape->CopyFrom(input_shape); - saved_inv_std_dev_shape->mutable_dim(static_cast(axis))->set_dim_value(1); + for (int d = static_cast(axis); d < input_ndim; ++d) + saved_inv_std_dev_shape->mutable_dim(d)->set_dim_value(1); } }) .SetContextDependentFunctionBodyBuilder( @@ -2507,9 +2509,11 @@ Example 4: {{"Deviation"}, "Sub", {"XU", "Mean2D"}}, {{"Normalized"}, "Div", {"Deviation", "StdDev"}}, {{"NormalizedT"}, "Cast", {"Normalized"}, {{"to", T}}}, - {{"Scaled"}, "Mul", {"NormalizedT", "Scale"}}}; + {{"Scale2D"}, "Flatten", {"Scale"}, {{"axis", int64_t(0)}}}, + {{"Scaled"}, "Mul", {"NormalizedT", "Scale2D"}}}; if (ctx.hasInput(2)) { - body.push_back({{"Biased"}, "Add", {"Scaled", "B"}}); + body.push_back({{"B2D"}, "Flatten", {"B"}, {{"axis", int64_t(0)}}}); + body.push_back({{"Biased"}, "Add", {"Scaled", "B2D"}}); } else { body.push_back({{"Biased"}, "Identity", {"Scaled"}}); } diff --git a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h index d5b98bdb79..61694eafe5 100644 --- a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h +++ b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h @@ -9,6 +9,7 @@ #include "onnx/onnx-operators_pb.h" #include "onnx/defs/schema.h" #include "onnx/defs/function.h" +#include "onnx/defs/parser.h" namespace ONNX_NAMESPACE { @@ -20,4 +21,55 @@ inline static FunctionBodyHelper::NodeDef Const(const std::string& name, double return FunctionBodyHelper::NodeDef{ {name}, "Constant", {}, {{"value", ToTensor(value, elem_type)}}}; } + +class FunctionBuilder { + public: + FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {} + + FunctionBuilder& Add(const char* nodes_txt) { + OnnxParser parser(nodes_txt); + auto& nodes = *funProto.mutable_node(); + + while (!parser.EndOfInput()) { + auto status = parser.Parse(*nodes.Add()); + if (!status.IsOK()) + ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage())); + } + + return *this; + } + + FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) { + OnnxParser parser(node_txt); + auto& node = *funProto.add_node(); + auto status = parser.Parse(node); + if (!status.IsOK()) { + ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage())); + } + + if (!parser.EndOfInput()) { + ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage())); + } + + *node.add_attribute() = attr; + + return *this; + } + + template + FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) { + return Add (node_txt, MakeAttribute(attr_name, attr_value)); + } + + FunctionBuilder& AddOpset(const char* domain, int version) { + auto* opset = funProto.add_opset_import(); + opset->set_domain(domain); + opset->set_version(version); + return *this; + } + + private: + FunctionProto& funProto; +}; + } // namespace ONNX_NAMESPACE \ No newline at end of file diff --git a/onnxruntime/test/contrib_ops/function_ops_test.cc b/onnxruntime/test/contrib_ops/function_ops_test.cc index bae5a9c97b..db7e22904e 100644 --- a/onnxruntime/test/contrib_ops/function_ops_test.cc +++ b/onnxruntime/test/contrib_ops/function_ops_test.cc @@ -27,10 +27,8 @@ class ContribFunExpansionTest : public ::testing::Test { }; template -void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) { +void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true, std::vector shape1 = {8, 16}, std::vector shape2 = {16}, int64_t axis = -1) { FunctionTestCase testCase("LayerNormalization", kOnnxDomain); - std::vector shape1{8, 16}; - std::vector shape2{16}; testCase.AddInput("x", shape1); testCase.AddInput("scale", shape2); @@ -39,6 +37,8 @@ void CheckLayerNorm(bool compute_mean = true, bool compute_isd = true) { testCase.AddOutput(compute_mean ? "mean" : ""); testCase.AddOutput(compute_isd ? "invstddev" : ""); testCase.AddAttribute("stash_type", data_types_internal::ToTensorDataType()); + if (axis != -1) + testCase.AddAttribute("axis", axis); if (RunTest) testCase.RunTest(); else @@ -59,6 +59,11 @@ TEST_F(ContribFunExpansionTest, LayerNorm_OptionalOutputs) { CheckLayerNorm(true, false); } +TEST_F(ContribFunExpansionTest, LayerNorm_OtherShapes) { + // Test expand-and-run + CheckLayerNorm(true, true, {4, 2, 8}, {2, 8}, 1); +} + template void CheckGelu() { FunctionTestCase testCase("Gelu", kMSDomain); diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 9efe5cbb61..6f2c818f4d 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2095,7 +2095,7 @@ Example 4: .Input(1, "X", "Input data tensor from the forward path", "T") .Input(2, "scale", "Scale tensor.", "T") .Input(3, "mean", "mean of X.", "U") - .Input(4, "inv_std_var", "inverse std variance of X.", "U") + .Input(4, "inv_std_dev", "inverse std deviation of X.", "U") .Output(0, "X_grad", "Gradient of the input.", "T") .Output(1, "scale_grad", "Gradient of the scale.", "T") .Output(2, "bias_grad", "Gradient of the bias.", "T") @@ -2115,7 +2115,59 @@ Example 4: // The bias tensor has the same shape of the scale tensor. propagateElemTypeFromInputToOutput(ctx, 2, 2); propagateShapeFromInputToOutput(ctx, 2, 2); - }); + }) + .SetContextDependentFunctionBodyBuilder( + [](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) { + FunctionBuilder builder(functionProto); + + auto* tp = ctx.getInputType(0); + if ((tp == nullptr) || (!tp->has_tensor_type())) + return false; + int64_t T = tp->tensor_type().elem_type(); + + // Requirements/assumptions: + // Inputs Y_grad and X are of shape [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]] and type T + // Input scale is of shape [d[axis], ..., d[rank-1]] and type U + // Inputs mean and inv_std_dev are of shape [d[0], ..., d[axis-1], 1, ..., 1] (same rank as X) + // and type U. + // + auto axis_ref_attr = MakeRefAttribute("axis", AttributeProto_AttributeType::AttributeProto_AttributeType_INT); + builder + .AddOpset("", 15) + .Add("cast_mean = Cast (mean)", "to", T) + .Add("cast_inv_std_dev = Cast(inv_std_dev)", "to", T) + .Add("x_2d = Flatten (X)", axis_ref_attr) + .Add("Y_grad_2d = Flatten (Y_grad)", axis_ref_attr) + .Add("mean_2d = Flatten (cast_mean)", axis_ref_attr) + .Add("inv_std_dev_2d = Flatten (cast_inv_std_dev)", axis_ref_attr) + .Add(R"ONNX( + shape_x = Shape (X) + bias_scale_shape = Shape (scale) + scale_2d = Flatten (scale) + + axis_0 = Constant () + bias_grad_2d = ReduceSum (Y_grad_2d, axis_0) + bias_grad = Reshape (bias_grad_2d, bias_scale_shape) + + deviation = Sub (x_2d, mean_2d) + normalized_deviation = Mul(deviation, inv_std_dev_2d) + scale_grad_rows = Mul (Y_grad_2d, normalized_deviation) + scale_grad_2d = ReduceSum (scale_grad_rows, axis_0) + scale_grad = Reshape (scale_grad_2d, bias_scale_shape) + normalized_layer_grad = Mul (Y_grad_2d, scale_2d) + + B = Mul (normalized_layer_grad, inv_std_dev_2d) + C = Mul (B, normalized_deviation) + mean_B = ReduceMean (B) + mean_C = ReduceMean (C) + nd_mean_C = Mul (normalized_deviation, mean_C) + mean_diff_B = Sub (B, mean_B) + X_grad_2D = Sub (mean_diff_B, nd_mean_C) + X_grad = Reshape (X_grad_2D, shape_x) + )ONNX"); + schema.BuildFunction(functionProto); + return true; + }); ONNX_CONTRIB_OPERATOR_SCHEMA(SimplifiedLayerNormalizationGrad) .SetDomain(kMSDomain) diff --git a/orttraining/orttraining/test/gradient/function_ops_test.cc b/orttraining/orttraining/test/gradient/function_ops_test.cc index 5df08a59b5..9245bac3b8 100644 --- a/orttraining/orttraining/test/gradient/function_ops_test.cc +++ b/orttraining/orttraining/test/gradient/function_ops_test.cc @@ -171,7 +171,6 @@ TEST_F(FunExpansionTest, DropoutGrad_WithRatio2) { template void TestUnaryOpGrad(const char* opname) { - FunctionTestCase testCase(opname); std::vector shape{16, 4}; testCase.AddInput("dY", shape); @@ -199,5 +198,38 @@ TEST_F(FunExpansionTest, FastGeluGrad) { TestUnaryOpGrad("FastGeluGrad"); } +template +void TestLayerNormGrad(std::vector prefix_shape, std::vector suffix_shape) { + FunctionTestCase testCase("LayerNormalizationGrad"); + std::vector input_shape(prefix_shape); + for (auto d : suffix_shape) + input_shape.push_back(d); + std::vector stats_shape(prefix_shape); + for (auto d : suffix_shape) { + (void)d; + stats_shape.push_back(1); + } + testCase.AddInput("Y_grad", input_shape); + testCase.AddInput("X", input_shape); + testCase.AddInput("scale", suffix_shape); + testCase.AddInput("mean", stats_shape); + testCase.AddInput("inv_std_dev", stats_shape); + testCase.AddOutput("X_grad"); + testCase.AddOutput("scale_grad"); + testCase.AddOutput("bias_grad"); + testCase.AddAttribute("axis", prefix_shape.size()); + if (RunTest) + testCase.RunTest(); + else + // Test only expanded model creation and model checking. + testCase.CreateModel(true); +} + +TEST_F(FunExpansionTest, LayerNormalizationGrad) { + TestLayerNormGrad({4, 1}, {8, 4}); + TestLayerNormGrad({}, {8, 4}); + TestLayerNormGrad({}, {8, 4}); +} + } // namespace test } // namespace onnxruntime