ScatterNDGrad (#8261)

This commit is contained in:
Vincent Wang 2021-07-01 13:49:49 +08:00 committed by GitHub
parent 97f1eea2ea
commit ef8f50c4ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 121 additions and 2 deletions

View file

@ -1836,5 +1836,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetPadGradient) {
NodeDef("Pad", {GO(0), IA("Neg_pads")}, {GI(0)})};
}
IMPLEMENT_GRADIENT_BUILDER(GetScatterNDGradient) {
std::vector<NodeDef> result;
if (IsGradientRequiredForSrcNodeInput(0)) {
result.emplace_back(NodeDef("Shape", {I(2)}, {IA("Shape_updates")}));
result.emplace_back(NodeDef("ConstantOfShape", {IA("Shape_updates")}, {IA("Zero_Shape_updates")},
{MakeAttribute("value", ScalarTensorProtoByElemType(0.0f, IElemType(0)))}));
result.emplace_back(NodeDef("ScatterND", {GO(0), I(1), IA("Zero_Shape_updates")}, {GI(0)}));
}
if (IsGradientRequiredForSrcNodeInput(2)) {
result.emplace_back(NodeDef("GatherND", {GO(0), I(1)}, {GI(2)}));
}
return result;
}
} // namespace training
} // namespace onnxruntime

View file

@ -77,6 +77,7 @@ DECLARE_GRADIENT_BUILDER(GetATenOpGradient)
DECLARE_GRADIENT_BUILDER(GetPadGradient)
DECLARE_GRADIENT_BUILDER(GetIdentityGradient)
DECLARE_GRADIENT_BUILDER(GetPythonOpGradient)
DECLARE_GRADIENT_BUILDER(GetScatterNDGradient)
} // namespace training
} // namespace onnxruntime

View file

@ -210,14 +210,19 @@ class GradientBuilderBase {
}
template <typename T>
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
static ONNX_NAMESPACE::TensorProto ScalarTensorProto(T value, std::vector<int64_t> shape) {
ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1));
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(value);
for (auto dim : shape) {
t_proto.add_dims(dim);
}
return t_proto;
}
template <typename T>
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
auto t_proto = ScalarTensorProto(value, shape);
return NodeDef("Constant",
{},
{ArgDef(arg_name, nullptr)},
@ -237,6 +242,18 @@ class GradientBuilderBase {
return ConstantScalarNode(value, {1}, arg_name);
}
static ONNX_NAMESPACE::TensorProto ScalarTensorProtoByElemType(float value, int elem_type) {
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
return ScalarTensorProto(MLFloat16(math::floatToHalf(value)), {1});
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
return ScalarTensorProto(BFloat16(value), {1});
}
return ScalarTensorProto(value, {1});
}
static NodeDef ZeroConstantNode(int elem_type) {
return ConstantScalarNode(0.0f, "ZeroConstant_Type" + std::to_string(elem_type), elem_type);
}

View file

@ -109,6 +109,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Pad", GetPadGradient);
REGISTER_GRADIENT_BUILDER("Identity", GetIdentityGradient);
REGISTER_GRADIENT_BUILDER("PythonOp", GetPythonOpGradient);
REGISTER_GRADIENT_BUILDER("ScatterND", GetScatterNDGradient);
};
} // namespace training

View file

@ -2653,6 +2653,90 @@ TEST(GradientCheckerTest, PadGrad) {
}
#endif // USE_CUDA
TEST(GradientCheckerTest, ScatterNDGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"ScatterND", kOnnxDomain, 11};
{
TensorInfo data_info({8}, true);
TensorInfo indices_info({4, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo updates_info({4}, true);
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {4, 3, 1, 7}, {8, 9, 10, 11}};
TensorInfo output_info({8}, true);
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
{output_info}, &max_error, input_datas);
EXPECT_IS_TINY(max_error);
}
{
TensorInfo data_info({2, 2}, true);
TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo updates_info({2}, true);
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3}, {0, 0, 1, 1}, {4, 5}};
TensorInfo output_info({2, 2}, true);
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
{output_info}, &max_error, input_datas);
EXPECT_IS_TINY(max_error);
}
{
TensorInfo data_info({2, 2}, true);
TensorInfo indices_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo updates_info({2, 2}, true);
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3}, {1, 0}, {4, 5, 6, 7}};
TensorInfo output_info({2, 2}, true);
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
{output_info}, &max_error, input_datas);
EXPECT_IS_TINY(max_error);
}
{
TensorInfo data_info({2, 2, 2}, true);
TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo updates_info({2, 2}, true);
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 1, 0}, {8, 9, 10, 11}};
TensorInfo output_info({2, 2, 2}, true);
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
{output_info}, &max_error, input_datas);
EXPECT_IS_TINY(max_error);
}
{
TensorInfo data_info({2, 2, 2}, true);
TensorInfo indices_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo updates_info({2, 1, 2}, true);
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 1, 0}, {8, 9, 10, 11}};
TensorInfo output_info({2, 2, 2}, true);
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
{output_info}, &max_error, input_datas);
EXPECT_IS_TINY(max_error);
}
{
TensorInfo data_info({2, 2, 2}, true);
TensorInfo indices_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
TensorInfo updates_info({2, 2, 2}, true);
std::vector<std::vector<float>> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1}, {8, 9, 10, 11, 12, 13, 14, 15}};
TensorInfo output_info({2, 2, 2}, true);
gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info},
{output_info}, &max_error, input_datas);
EXPECT_IS_TINY(max_error);
}
}
} // namespace test
} // namespace onnxruntime