mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Add ReduceL2Grad and ClipGrad (#5970)
* ReduceL2Grad and ClipGrad. * fix win build and amd ci pipeline * resolve comments. Co-authored-by: Vincent Wang <weicwang@AiFramework2080ti2.corp.microsoft.com>
This commit is contained in:
parent
404982ded5
commit
7ddeafdfcc
6 changed files with 128 additions and 2 deletions
|
|
@ -65,7 +65,8 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
|
|||
{"Squeeze", {1}},
|
||||
{"Unsqueeze", {1}},
|
||||
{"ReduceSum", {1}},
|
||||
{"Split", {1}}};
|
||||
{"Split", {1}},
|
||||
{"Clip", {1, 2}}};
|
||||
|
||||
class GradientGraphBuilder {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -1101,6 +1101,35 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
|
|||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
|
||||
std::vector<NodeDef> result;
|
||||
auto attributes = SrcNodeAttributes();
|
||||
bool keepdims = true;
|
||||
if (attributes.find("keepdims") != attributes.end() && attributes.at("keepdims").has_i()) {
|
||||
keepdims = static_cast<bool>(attributes.at("keepdims").i());
|
||||
}
|
||||
|
||||
result.emplace_back(NodeDef("Div", {GO(0), O(0)}, {IA("Scaled_dY")}));
|
||||
|
||||
// Handle 0 elements in Y.
|
||||
NodeDef zero_constant_node = ZeroConstantNode(IElemType(0));
|
||||
ArgDef ZERO = zero_constant_node.output_args[0];
|
||||
result.push_back(zero_constant_node);
|
||||
result.emplace_back(NodeDef("Equal", {O(0), ZERO}, {IA("Masked_Y")}));
|
||||
ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY");
|
||||
result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def}));
|
||||
|
||||
if (!keepdims && attributes.find("axes") != attributes.end()) {
|
||||
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
|
||||
scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY");
|
||||
result.emplace_back(
|
||||
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
|
||||
}
|
||||
|
||||
result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)}));
|
||||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
|
||||
std::vector<NodeDef> result;
|
||||
auto attributes = SrcNodeAttributes();
|
||||
|
|
@ -1444,5 +1473,38 @@ IMPLEMENT_GRADIENT_BUILDER(GetTopKGradient) {
|
|||
{MakeAttribute("axis", axis)})};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetClipGradient) {
|
||||
std::vector<NodeDef> output;
|
||||
size_t numInputs = GetSrcNodeInputSize();
|
||||
bool has_i1 = false, has_i2 = false;
|
||||
ArgDef intermediate_arg_def = ArgDef("");
|
||||
// Gradients not defined on min and max, so we return the subgradient 1 for these cases.
|
||||
if (numInputs >= 2 && I(1).Exists()) {
|
||||
has_i1 = true;
|
||||
intermediate_arg_def = IA("Masked_Min");
|
||||
output.emplace_back(NodeDef("GreaterOrEqual", {I(0), I(1)}, {intermediate_arg_def}));
|
||||
}
|
||||
|
||||
if (numInputs >= 3 && I(2).Exists()) {
|
||||
has_i2 = true;
|
||||
intermediate_arg_def = IA("Masked_Max");
|
||||
output.emplace_back(NodeDef("LessOrEqual", {I(0), I(2)}, {intermediate_arg_def}));
|
||||
if (has_i1) {
|
||||
intermediate_arg_def = IA("Masked_Min_Max");
|
||||
output.emplace_back(NodeDef("And", {IA("Masked_Min"), IA("Masked_Max")}, {intermediate_arg_def}));
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_i1 && !has_i2) {
|
||||
output.emplace_back(NodeDef("Identity", {GO(0)}, {GI(0)}));
|
||||
} else {
|
||||
output.emplace_back(
|
||||
NodeDef("Cast", {intermediate_arg_def}, {IA("Casted_Mask")}, {MakeAttribute("to", int64_t(IElemType(0)))}));
|
||||
output.emplace_back(NodeDef("Mul", {GO(0), IA("Casted_Mask")}, {GI(0)}));
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ DECLARE_GRADIENT_BUILDER(GetNegGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReduceL2Gradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetPowGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConcatTrainingGradient)
|
||||
|
|
@ -68,6 +69,7 @@ DECLARE_GRADIENT_BUILDER(GetExpandGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetExpGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetFlattenGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetTopKGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetClipGradient)
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ReduceL2", GetReduceL2Gradient);
|
||||
REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
|
||||
|
|
@ -99,6 +100,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Exp", GetExpGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Flatten", GetFlattenGradient);
|
||||
REGISTER_GRADIENT_BUILDER("TopK", GetTopKGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Clip", GetClipGradient);
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -593,6 +593,28 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
|
|||
RunReductionTests(op_def_13, true, true);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReduceL2Grad) {
|
||||
// Attribute axes supports negative values from opset 11.
|
||||
OpDef op_def{"ReduceL2", kOnnxDomain, 11};
|
||||
|
||||
RunReductionTests(op_def);
|
||||
|
||||
// Y with 0 elements case.
|
||||
{
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
||||
TensorInfo x_info({4, 2}, true);
|
||||
std::vector<std::vector<float>> x_datas = {{1, 1, 0, 0, 3, 0, 0, 0}};
|
||||
|
||||
TensorInfo y_info({4, 1}, true);
|
||||
std::vector<int64_t> axes{-1};
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info}, {y_info}, &max_error, x_datas,
|
||||
{MakeAttribute("axes", axes)});
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
|
||||
// Attribute axes supports negative values from opset 11.
|
||||
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
|
||||
|
|
@ -2158,6 +2180,41 @@ TEST(GradientCheckerTest, TopKGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ClipGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Clip", kOnnxDomain, 12};
|
||||
|
||||
{
|
||||
TensorInfo x_info({2, 2, 2}, true);
|
||||
TensorInfo min_info({}, false);
|
||||
TensorInfo max_info({}, false);
|
||||
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {2.8f}, {7.2f}};
|
||||
TensorInfo y_info({2, 2, 2}, true);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, min_info, max_info}, {y_info}, &max_error, x_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
{
|
||||
TensorInfo x_info({2, 2, 2}, true);
|
||||
TensorInfo min_info({}, false);
|
||||
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}, {3.8f}};
|
||||
TensorInfo y_info({2, 2, 2}, true);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, min_info}, {y_info}, &max_error, x_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
// Should have a case with Op(x, null, max), but current ComputeGradientError doesn't support doing this.
|
||||
|
||||
{
|
||||
TensorInfo x_info({2, 2, 2}, true);
|
||||
std::vector<std::vector<float>> x_datas = {{1, 2, 3, 4, 5, 6, 7, 8}};
|
||||
TensorInfo y_info({2, 2, 2}, true);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info}, {y_info}, &max_error, x_datas);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -143,10 +143,12 @@ GradientCheckerTest.MatMulGrad
|
|||
GradientCheckerTest.ReduceMeanGrad
|
||||
GradientCheckerTest.ReduceSumGrad
|
||||
GradientCheckerTest.ReduceLogSumExpGrad
|
||||
GradientCheckerTest.ReduceL2Grad
|
||||
GradientCheckerTest.SoftmaxCrossEntropyGrad
|
||||
GradientCheckerTest.ExpandGrad
|
||||
GradientCheckerTest.DivGrad
|
||||
GradientCheckerTest.GemmGrad
|
||||
GradientCheckerTest.SplitGrad
|
||||
GradientCheckerTest.SqueezeGrad
|
||||
GradientCheckerTest.UnsqueezeGrad
|
||||
GradientCheckerTest.UnsqueezeGrad
|
||||
GradientCheckerTest.ClipGrad
|
||||
|
|
|
|||
Loading…
Reference in a new issue