Add ReduceL2Grad and ClipGrad (#5970)

* ReduceL2Grad and ClipGrad.

* Fix Windows build and AMD CI pipeline.

* Resolve review comments.

Co-authored-by: Vincent Wang <weicwang@AiFramework2080ti2.corp.microsoft.com>
This commit is contained in:
Vincent Wang 2020-12-10 11:03:26 +08:00 committed by GitHub
parent 404982ded5
commit 7ddeafdfcc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 2 deletions

View file

@ -65,7 +65,8 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"Squeeze", {1}},
{"Unsqueeze", {1}},
{"ReduceSum", {1}},
{"Split", {1}}};
{"Split", {1}},
{"Clip", {1, 2}}};
class GradientGraphBuilder {
public:

View file

@ -1101,6 +1101,35 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
return result;
}
// Gradient builder registered for the "ReduceL2" op.
//
// Emitted subgraph, in order:
//   Scaled_dY        = Div(dY, Y)
//   Masked_Y         = Equal(Y, 0)
//   Masked_Scaled_dY = Where(Masked_Y, 0, Scaled_dY)   // zero out positions where Y == 0
//   [Unsqueeze over "axes" when keepdims is false and "axes" is present]
//   dX               = Mul(X, <masked, possibly unsqueezed, scaled dY>)
IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
  const auto& src_attrs = SrcNodeAttributes();

  // keepdims is treated as true unless the source node carries an integer "keepdims" attribute.
  bool keep_reduced_dims = true;
  const auto keepdims_it = src_attrs.find("keepdims");
  if (keepdims_it != src_attrs.end() && keepdims_it->second.has_i()) {
    keep_reduced_dims = static_cast<bool>(keepdims_it->second.i());
  }

  std::vector<NodeDef> nodes;

  // dY / Y; entries where Y is 0 are replaced below.
  const ArgDef scaled_dy = IA("Scaled_dY");
  nodes.push_back(NodeDef("Div", {GO(0), O(0)}, {scaled_dy}));

  // Constant 0 of the input element type, used both as the comparison operand and the replacement value.
  NodeDef zero_node = ZeroConstantNode(IElemType(0));
  const ArgDef zero = zero_node.output_args[0];
  nodes.push_back(zero_node);

  // Wherever Y == 0, use 0 instead of the result of the division by zero.
  const ArgDef y_is_zero = IA("Masked_Y");
  nodes.push_back(NodeDef("Equal", {O(0), zero}, {y_is_zero}));

  ArgDef multiplier = IA("Masked_Scaled_dY");
  nodes.push_back(NodeDef("Where", {y_is_zero, zero, scaled_dy}, {multiplier}));

  // With keepdims off the reduced axes are missing from Y/dY; put them back so the
  // final Mul can broadcast against X. Nothing is inserted when "axes" is absent.
  if (!keep_reduced_dims) {
    const auto axes_it = src_attrs.find("axes");
    if (axes_it != src_attrs.end()) {
      std::vector<int64_t> reduced_axes = RetrieveValues<int64_t>(axes_it->second);
      const ArgDef unsqueezed = IA("Unsqueezed_Masked_Scaled_dY");
      nodes.push_back(
          NodeDef("Unsqueeze", {multiplier}, {unsqueezed}, {MakeAttribute("axes", reduced_axes)}));
      multiplier = unsqueezed;
    }
  }

  nodes.push_back(NodeDef("Mul", {I(0), multiplier}, {GI(0)}));
  return nodes;
}
IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
@ -1444,5 +1473,38 @@ IMPLEMENT_GRADIENT_BUILDER(GetTopKGradient) {
{MakeAttribute("axis", axis)})};
}
// Gradient builder registered for the "Clip" op.
//
// dX = dY * mask, where mask is 1 for elements of X inside [min, max] and 0 elsewhere.
// Both comparisons are inclusive, so elements exactly equal to a bound still pass the
// gradient through (subgradient 1 at the points where the gradient is not defined).
// When neither bound is supplied as an input, dX is simply Identity(dY).
//
// NOTE(review): only the optional min/max *inputs* (indices 1 and 2) are inspected here;
// node attributes are not read by this builder.
IMPLEMENT_GRADIENT_BUILDER(GetClipGradient) {
  const size_t input_count = GetSrcNodeInputSize();
  const bool has_min = input_count >= 2 && I(1).Exists();
  const bool has_max = input_count >= 3 && I(2).Exists();

  // No bounds at all: the op does not alter X, so the gradient flows through unchanged.
  if (!has_min && !has_max) {
    return std::vector<NodeDef>{NodeDef("Identity", {GO(0)}, {GI(0)})};
  }

  std::vector<NodeDef> nodes;
  ArgDef in_range_mask = ArgDef("");

  if (has_min) {
    in_range_mask = IA("Masked_Min");
    nodes.push_back(NodeDef("GreaterOrEqual", {I(0), I(1)}, {in_range_mask}));
  }

  if (has_max) {
    in_range_mask = IA("Masked_Max");
    nodes.push_back(NodeDef("LessOrEqual", {I(0), I(2)}, {in_range_mask}));

    // Both bounds present: an element must satisfy both comparisons.
    if (has_min) {
      in_range_mask = IA("Masked_Min_Max");
      nodes.push_back(NodeDef("And", {IA("Masked_Min"), IA("Masked_Max")}, {in_range_mask}));
    }
  }

  // Convert the boolean mask to the element type of X so it can be multiplied with dY.
  nodes.push_back(NodeDef("Cast", {in_range_mask}, {IA("Casted_Mask")},
                          {MakeAttribute("to", static_cast<int64_t>(IElemType(0)))}));
  nodes.push_back(NodeDef("Mul", {GO(0), IA("Casted_Mask")}, {GI(0)}));
  return nodes;
}
} // namespace training
} // namespace onnxruntime

View file

@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetNegGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
DECLARE_GRADIENT_BUILDER(GetReduceL2Gradient)
DECLARE_GRADIENT_BUILDER(GetPowGradient)
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient)
@ -68,6 +69,7 @@ DECLARE_GRADIENT_BUILDER(GetExpandGradient)
DECLARE_GRADIENT_BUILDER(GetExpGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenGradient)
DECLARE_GRADIENT_BUILDER(GetTopKGradient)
DECLARE_GRADIENT_BUILDER(GetClipGradient)
} // namespace training
} // namespace onnxruntime

View file

@ -56,6 +56,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
REGISTER_GRADIENT_BUILDER("ReduceL2", GetReduceL2Gradient);
REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
@ -99,6 +100,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient);
REGISTER_GRADIENT_BUILDER("TopK", GetTopKGradient);
REGISTER_GRADIENT_BUILDER("Clip", GetClipGradient);
};
} // namespace training

View file

@ -593,6 +593,28 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
RunReductionTests(op_def_13, true, true);
}
// Gradient check for ReduceL2: the shared reduction test suite, plus an input chosen so
// that some reduced outputs are exactly zero.
TEST(GradientCheckerTest, ReduceL2Grad) {
  // Opset 11 is used because the axes attribute accepts negative values from that version on.
  OpDef reduce_l2_def{"ReduceL2", kOnnxDomain, 11};
  RunReductionTests(reduce_l2_def);

  // Rows 1 and 3 of X are all zeros, so the matching entries of Y are 0.
  {
    float max_error;
    GradientChecker<float, float, float> gradient_checker;

    std::vector<int64_t> reduce_axes{-1};
    TensorInfo input_info({4, 2}, true);
    TensorInfo output_info({4, 1}, true);
    std::vector<std::vector<float>> input_values = {{1, 1,
                                                     0, 0,
                                                     3, 0,
                                                     0, 0}};

    gradient_checker.ComputeGradientError(reduce_l2_def, {input_info}, {output_info}, &max_error, input_values,
                                          {MakeAttribute("axes", reduce_axes)});
    EXPECT_IS_TINY(max_error);
  }
}
TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
@ -2158,6 +2180,41 @@ TEST(GradientCheckerTest, TopKGrad) {
}
}
// Gradient check for Clip (opset 12) with both bounds, with only min, and with no bounds.
// The bound values are chosen so that no element of X coincides with a bound.
TEST(GradientCheckerTest, ClipGrad) {
  float max_error;
  GradientChecker<float, float, float> gradient_checker;
  OpDef clip_def{"Clip", kOnnxDomain, 12};

  // Clip(x, min, max): min and max are scalar inputs that do not receive a gradient.
  {
    TensorInfo input_info({2, 2, 2}, true);
    TensorInfo lower_info({}, false);
    TensorInfo upper_info({}, false);
    TensorInfo output_info({2, 2, 2}, true);
    std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6, 7, 8}, {2.8f}, {7.2f}};

    gradient_checker.ComputeGradientError(clip_def, {input_info, lower_info, upper_info}, {output_info}, &max_error,
                                          input_values);
    EXPECT_IS_TINY(max_error);
  }

  // Clip(x, min): lower bound only.
  {
    TensorInfo input_info({2, 2, 2}, true);
    TensorInfo lower_info({}, false);
    TensorInfo output_info({2, 2, 2}, true);
    std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6, 7, 8}, {3.8f}};

    gradient_checker.ComputeGradientError(clip_def, {input_info, lower_info}, {output_info}, &max_error,
                                          input_values);
    EXPECT_IS_TINY(max_error);
  }

  // A Clip(x, null, max) case belongs here as well, but the current ComputeGradientError
  // does not support an omitted middle input.

  // Clip(x): no bounds.
  {
    TensorInfo input_info({2, 2, 2}, true);
    TensorInfo output_info({2, 2, 2}, true);
    std::vector<std::vector<float>> input_values = {{1, 2, 3, 4, 5, 6, 7, 8}};

    gradient_checker.ComputeGradientError(clip_def, {input_info}, {output_info}, &max_error, input_values);
    EXPECT_IS_TINY(max_error);
  }
}
} // namespace test
} // namespace onnxruntime

View file

@ -143,10 +143,12 @@ GradientCheckerTest.MatMulGrad
GradientCheckerTest.ReduceMeanGrad
GradientCheckerTest.ReduceSumGrad
GradientCheckerTest.ReduceLogSumExpGrad
GradientCheckerTest.ReduceL2Grad
GradientCheckerTest.SoftmaxCrossEntropyGrad
GradientCheckerTest.ExpandGrad
GradientCheckerTest.DivGrad
GradientCheckerTest.GemmGrad
GradientCheckerTest.SplitGrad
GradientCheckerTest.SqueezeGrad
GradientCheckerTest.UnsqueezeGrad
GradientCheckerTest.UnsqueezeGrad
GradientCheckerTest.ClipGrad