From 26cd3c1fb0245d05e3beb8a9f33ce5f5d274d111 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 19 Mar 2024 09:33:06 -0700 Subject: [PATCH] add kernel tests for ops that changed in opset18 (#19767) ### Description - [x] Pad operator has introduced a new input called "axes" which specifies which axis to pad. But it defaults to input_rank if axes is not provided which was the behavior before the opset upgrade. - [x] ReduceMean - [x] ReduceL2 - [x] ReduceLogSumExp - [x] ReduceSum - Reduction ops all had the axes attribute switched to an input and a new attribute called "noop_with_empty_axes" was added to define what to do when axes is not specified. - [x] Resize has had two new attributes introduced: antialias and keep_aspect_ratio_policy. From Operators.md I've gathered: "Antialiasing is achieved by stretching the resampling filter by a factor max(1, 1 / scale), which means that when downsampling, more input pixels contribute to an output pixel." keep_aspect_ratio_policy "describes how to interpret the `sizes` input with regard to keeping the original aspect ratio of the input." there are a couple enum-type options that specify different policies and what to do in each case. - NOTE: Baiju already included opset18 tests in https://github.com/microsoft/onnxruntime/pull/17772 - [x] ScatterElements/ScatterND has had a new attribute introduced called "reduction." This specifies the type of reduction to apply: none (default), add, mul, max, min. - [x] Split introduced a new attribute called "num_outputs" which specifies how many outputs to split the input tensor into. This is in contrast to the previous, default behavior of specifying a "split" input which defines the size of each resultant tensor of the output. ### Motivation and Context --- .../core/graph/gradient_builder.cc | 37 ++++++++++++++----- .../test/gradient/gradient_ops_test.cc | 30 +++++++++++++-- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index e675b55c8a..22dcf4eb92 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1112,6 +1112,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) { ArgDef grad = GO(0); if (!keepdims) { + size_t numInputs = GetSrcNodeInputSize(); if (attributes.find("axes") != attributes.end()) { std::vector axes_values = RetrieveValues(attributes.at("axes")); grad = IA("Unqueezed_Grad"); @@ -1122,6 +1123,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) { result.push_back(axes_values_node); result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad})); } + } else if (numInputs == 2) { // optional input 'axes' is available as input I(1) + grad = IA("Unqueezed_Grad"); + result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad})); } } @@ -1152,12 +1156,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) { } ArgDef grad = GO(0); - if (!keepdims && attributes.find("axes") != attributes.end()) { - std::vector axes_values = RetrieveValues(attributes.at("axes")); - grad = IA("Unsqueezed_Grad"); - result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)})); + if (!keepdims) { + size_t numInputs = GetSrcNodeInputSize(); + if (attributes.find("axes") != attributes.end()) { + std::vector axes_values = RetrieveValues(attributes.at("axes")); + grad = IA("Unsqueezed_Grad"); - result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)})); + result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)})); + + result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)})); + } else if (numInputs == 2) { // optional input 'axes' is available as input I(1) + grad = IA("Unsqueezed_Grad"); + result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad})); + + result.push_back(NodeDef("Unsqueeze", {O(0), I(1)}, {IA("Unsqueezed_Output")})); + } result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")})); } else { result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")})); @@ -1188,11 +1201,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) { ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY"); result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def})); - if (!keepdims && attributes.find("axes") != attributes.end()) { - std::vector axes_values = RetrieveValues(attributes.at("axes")); + if (!keepdims) { + size_t numInputs = GetSrcNodeInputSize(); scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY"); - result.emplace_back( - NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)})); + if (attributes.find("axes") != attributes.end()) { + std::vector axes_values = RetrieveValues(attributes.at("axes")); + result.emplace_back( + NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)})); + } else if (numInputs == 2) { // optional input 'axes' is available as input I(1) + result.emplace_back( + NodeDef("Unsqueeze", {IA("Masked_Scaled_dY"), I(1)}, {scaled_dy_arg_def})); + } } result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)})); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index feca94ae27..94ca96c68f 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -607,6 +607,10 @@ TEST(GradientCheckerTest, ReduceMeanGrad) { OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13}; RunReductionTests(op_def_opset13); + + // axes is input from opset 18. + OpDef op_def_opset18{"ReduceMean", kOnnxDomain, 18}; + RunReductionTests(op_def_opset18, true, true); } TEST(GradientCheckerTest, ReduceSumGrad) { @@ -619,6 +623,10 @@ TEST(GradientCheckerTest, ReduceSumGrad) { OpDef op_def_13{"ReduceSum", kOnnxDomain, 13}; RunReductionTests(op_def_13, true, true); + + OpDef op_def_18{"ReduceSum", kOnnxDomain, 18}; + + RunReductionTests(op_def_18, true, true); } TEST(GradientCheckerTest, ReduceL2Grad) { @@ -641,6 +649,11 @@ TEST(GradientCheckerTest, ReduceL2Grad) { {MakeAttribute("axes", axes)})); EXPECT_IS_TINY(max_error); } + + // axes is input from opset 18 + OpDef op_def_18{"ReduceL2", kOnnxDomain, 18}; + + RunReductionTests(op_def_18, true, true); } TEST(GradientCheckerTest, ReduceLogSumExpGrad) { @@ -648,6 +661,10 @@ TEST(GradientCheckerTest, ReduceLogSumExpGrad) { OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11}; RunReductionTests(op_def); + + OpDef op_def_opset18{"ReduceLogSumExp", kOnnxDomain, 18}; + + RunReductionTests(op_def_opset18, true, true); } TEST(GradientCheckerTest, ReluGrad) { @@ -698,6 +715,13 @@ TEST(GradientCheckerTest, SplitGrad) { ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_13, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error, {MakeAttribute("axis", int64_t(0))})); EXPECT_IS_TINY(max_error); + + // opset18 test + OpDef op_def_18{"Split", kOnnxDomain, 18}; + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_18, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error, + {MakeAttribute("axis", int64_t(0)), + MakeAttribute("num_outputs", int64_t(3))})); + EXPECT_IS_TINY(max_error); } template @@ -2733,7 +2757,7 @@ TEST(GradientCheckerTest, TileGrad) { TEST(GradientCheckerTest, PadGrad) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"Pad", kOnnxDomain, 11}; + OpDef op_def{"Pad", kOnnxDomain, 18}; { TensorInfo x_info({2, 4}, true); @@ -2803,7 +2827,7 @@ TEST(GradientCheckerTest, PadGrad) { TEST(GradientCheckerTest, ScatterNDGrad) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"ScatterND", kOnnxDomain, 11}; + OpDef op_def{"ScatterND", kOnnxDomain, 18}; { TensorInfo data_info({8}, true); @@ -2887,7 +2911,7 @@ TEST(GradientCheckerTest, ScatterNDGrad) { TEST(GradientCheckerTest, ScatterElementsGrad) { float max_error; GradientChecker gradient_checker; - OpDef op_def{"ScatterElements", kOnnxDomain, 13}; + OpDef op_def{"ScatterElements", kOnnxDomain, 18}; { // without axis TensorInfo data_info({3, 3}, true);