From 7ddeafdfcc498681f3a665ffa609646bf52edeaf Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 10 Dec 2020 11:03:26 +0800 Subject: [PATCH] Add ReduceL2Grad and ClipGrad (#5970) * ReduceL2Grad and ClipGrad. * fix win build and amd ci pipeline * resolve comments. Co-authored-by: Vincent Wang --- .../core/framework/gradient_graph_builder.h | 3 +- .../core/graph/gradient_builder.cc | 62 +++++++++++++++++++ .../orttraining/core/graph/gradient_builder.h | 2 + .../core/graph/gradient_builder_registry.cc | 2 + .../test/gradient/gradient_ops_test.cc | 57 +++++++++++++++++ .../github/pai/pai-excluded-tests.txt | 4 +- 6 files changed, 128 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index e71278df82..f475c17dca 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -65,7 +65,8 @@ static std::unordered_map> {"Squeeze", {1}}, {"Unsqueeze", {1}}, {"ReduceSum", {1}}, - {"Split", {1}}}; + {"Split", {1}}, + {"Clip", {1, 2}}}; class GradientGraphBuilder { public: diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index cb8454ee58..57193591bf 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1101,6 +1101,35 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) { return result; } +IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) { + std::vector result; + auto attributes = SrcNodeAttributes(); + bool keepdims = true; + if (attributes.find("keepdims") != attributes.end() && attributes.at("keepdims").has_i()) { + keepdims = static_cast(attributes.at("keepdims").i()); + } + + result.emplace_back(NodeDef("Div", {GO(0), O(0)}, {IA("Scaled_dY")})); + + // Handle 0 elements in Y. + NodeDef zero_constant_node = ZeroConstantNode(IElemType(0)); + ArgDef ZERO = zero_constant_node.output_args[0]; + result.push_back(zero_constant_node); + result.emplace_back(NodeDef("Equal", {O(0), ZERO}, {IA("Masked_Y")})); + ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY"); + result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def})); + + if (!keepdims && attributes.find("axes") != attributes.end()) { + std::vector axes_values = RetrieveValues(attributes.at("axes")); + scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY"); + result.emplace_back( + NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)})); + } + + result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)})); + return result; +} + IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) { std::vector result; auto attributes = SrcNodeAttributes(); @@ -1444,5 +1473,38 @@ IMPLEMENT_GRADIENT_BUILDER(GetTopKGradient) { {MakeAttribute("axis", axis)})}; } +IMPLEMENT_GRADIENT_BUILDER(GetClipGradient) { + std::vector output; + size_t numInputs = GetSrcNodeInputSize(); + bool has_i1 = false, has_i2 = false; + ArgDef intermediate_arg_def = ArgDef(""); + // Gradients not defined on min and max, so we return the subgradient 1 for these cases. + if (numInputs >= 2 && I(1).Exists()) { + has_i1 = true; + intermediate_arg_def = IA("Masked_Min"); + output.emplace_back(NodeDef("GreaterOrEqual", {I(0), I(1)}, {intermediate_arg_def})); + } + + if (numInputs >= 3 && I(2).Exists()) { + has_i2 = true; + intermediate_arg_def = IA("Masked_Max"); + output.emplace_back(NodeDef("LessOrEqual", {I(0), I(2)}, {intermediate_arg_def})); + if (has_i1) { + intermediate_arg_def = IA("Masked_Min_Max"); + output.emplace_back(NodeDef("And", {IA("Masked_Min"), IA("Masked_Max")}, {intermediate_arg_def})); + } + } + + if (!has_i1 && !has_i2) { + output.emplace_back(NodeDef("Identity", {GO(0)}, {GI(0)})); + } else { + output.emplace_back( + NodeDef("Cast", {intermediate_arg_def}, {IA("Casted_Mask")}, {MakeAttribute("to", int64_t(IElemType(0)))})); + output.emplace_back(NodeDef("Mul", {GO(0), IA("Casted_Mask")}, {GI(0)})); + } + + return output; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 3c6dc7bf5f..6bc9207cf0 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetNegGradient) DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient) DECLARE_GRADIENT_BUILDER(GetReduceSumGradient) DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient) +DECLARE_GRADIENT_BUILDER(GetReduceL2Gradient) DECLARE_GRADIENT_BUILDER(GetPowGradient) DECLARE_GRADIENT_BUILDER(GetConcatGradient) DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient) @@ -68,6 +69,7 @@ DECLARE_GRADIENT_BUILDER(GetExpandGradient) DECLARE_GRADIENT_BUILDER(GetExpGradient) DECLARE_GRADIENT_BUILDER(GetFlattenGradient) DECLARE_GRADIENT_BUILDER(GetTopKGradient) +DECLARE_GRADIENT_BUILDER(GetClipGradient) } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index f6f5218bc6..335dc4744f 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -56,6 +56,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient); REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient); REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient); + REGISTER_GRADIENT_BUILDER("ReduceL2", GetReduceL2Gradient); REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient); REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient); REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient); @@ -99,6 +100,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient); REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient); REGISTER_GRADIENT_BUILDER("TopK", GetTopKGradient); + REGISTER_GRADIENT_BUILDER("Clip", GetClipGradient); }; } // namespace training diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index e0043eb281..63be72b6cc 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -593,6 +593,28 @@ TEST(GradientCheckerTest, ReduceSumGrad) { RunReductionTests(op_def_13, true, true); } +TEST(GradientCheckerTest, ReduceL2Grad) { + // Attribute axes supports negative values from opset 11. + OpDef op_def{"ReduceL2", kOnnxDomain, 11}; + + RunReductionTests(op_def); + + // Y with 0 elements case. + { + float max_error; + GradientChecker gradient_checker; + + TensorInfo x_info({4, 2}, true); + std::vector> x_datas = {{1, 1, 0, 0, 3, 0, 0, 0}}; + + TensorInfo y_info({4, 1}, true); + std::vector axes{-1}; + gradient_checker.ComputeGradientError(op_def, {x_info}, {y_info}, &max_error, x_datas, + {MakeAttribute("axes", axes)}); + EXPECT_IS_TINY(max_error); + } +} + TEST(GradientCheckerTest, ReduceLogSumExpGrad) { // Attribute axes supports negative values from opset 11. OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11}; @@ -2158,6 +2180,41 @@ TEST(GradientCheckerTest, TopKGrad) { } } +TEST(GradientCheckerTest, ClipGrad) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Clip", kOnnxDomain, 12}; + + { + TensorInfo x_info({2, 2, 2}, true); + TensorInfo min_info({}, false); + TensorInfo max_info({}, false); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {2.8f}, {7.2f}}; + TensorInfo y_info({2, 2, 2}, true); + gradient_checker.ComputeGradientError(op_def, {x_info, min_info, max_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo x_info({2, 2, 2}, true); + TensorInfo min_info({}, false); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {3.8f}}; + TensorInfo y_info({2, 2, 2}, true); + gradient_checker.ComputeGradientError(op_def, {x_info, min_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } + + // Should have a case with Op(x, null, max), but current ComputeGradientError doesn't support doing this. + + { + TensorInfo x_info({2, 2, 2}, true); + std::vector> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}}; + TensorInfo y_info({2, 2, 2}, true); + gradient_checker.ComputeGradientError(op_def, {x_info}, {y_info}, &max_error, x_datas); + EXPECT_IS_TINY(max_error); + } +} + } // namespace test } // namespace onnxruntime diff --git a/tools/ci_build/github/pai/pai-excluded-tests.txt b/tools/ci_build/github/pai/pai-excluded-tests.txt index 7330fe45c1..382102ce74 100644 --- a/tools/ci_build/github/pai/pai-excluded-tests.txt +++ b/tools/ci_build/github/pai/pai-excluded-tests.txt @@ -143,10 +143,12 @@ GradientCheckerTest.MatMulGrad GradientCheckerTest.ReduceMeanGrad GradientCheckerTest.ReduceSumGrad GradientCheckerTest.ReduceLogSumExpGrad +GradientCheckerTest.ReduceL2Grad GradientCheckerTest.SoftmaxCrossEntropyGrad GradientCheckerTest.ExpandGrad GradientCheckerTest.DivGrad GradientCheckerTest.GemmGrad GradientCheckerTest.SplitGrad GradientCheckerTest.SqueezeGrad -GradientCheckerTest.UnsqueezeGrad +GradientCheckerTest.UnsqueezeGrad +GradientCheckerTest.ClipGrad