diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 534d5c0a8e..9ad11f56a2 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1836,5 +1836,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetPadGradient) { NodeDef("Pad", {GO(0), IA("Neg_pads")}, {GI(0)})}; } +IMPLEMENT_GRADIENT_BUILDER(GetScatterNDGradient) { + std::vector result; + if (IsGradientRequiredForSrcNodeInput(0)) { + result.emplace_back(NodeDef("Shape", {I(2)}, {IA("Shape_updates")})); + result.emplace_back(NodeDef("ConstantOfShape", {IA("Shape_updates")}, {IA("Zero_Shape_updates")}, + {MakeAttribute("value", ScalarTensorProtoByElemType(0.0f, IElemType(0)))})); + result.emplace_back(NodeDef("ScatterND", {GO(0), I(1), IA("Zero_Shape_updates")}, {GI(0)})); + } + + if (IsGradientRequiredForSrcNodeInput(2)) { + result.emplace_back(NodeDef("GatherND", {GO(0), I(1)}, {GI(2)})); + } + + return result; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 7027d62751..3696c23000 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -77,6 +77,7 @@ DECLARE_GRADIENT_BUILDER(GetATenOpGradient) DECLARE_GRADIENT_BUILDER(GetPadGradient) DECLARE_GRADIENT_BUILDER(GetIdentityGradient) DECLARE_GRADIENT_BUILDER(GetPythonOpGradient) +DECLARE_GRADIENT_BUILDER(GetScatterNDGradient) } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 19a73bf77f..b7fe60ed0b 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -210,14 +210,19 @@ class GradientBuilderBase { } template - static NodeDef ConstantScalarNode(T value, std::vector shape, const std::string& arg_name) { + static ONNX_NAMESPACE::TensorProto ScalarTensorProto(T value, std::vector shape) { ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1)); - auto t_proto = ONNX_NAMESPACE::ToTensor(value); for (auto dim : shape) { t_proto.add_dims(dim); } + return t_proto; + } + + template + static NodeDef ConstantScalarNode(T value, std::vector shape, const std::string& arg_name) { + auto t_proto = ScalarTensorProto(value, shape); return NodeDef("Constant", {}, {ArgDef(arg_name, nullptr)}, @@ -237,6 +242,18 @@ class GradientBuilderBase { return ConstantScalarNode(value, {1}, arg_name); } + static ONNX_NAMESPACE::TensorProto ScalarTensorProtoByElemType(float value, int elem_type) { + if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + return ScalarTensorProto(MLFloat16(math::floatToHalf(value)), {1}); + } + + if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + return ScalarTensorProto(BFloat16(value), {1}); + } + + return ScalarTensorProto(value, {1}); + } + static NodeDef ZeroConstantNode(int elem_type) { return ConstantScalarNode(0.0f, "ZeroConstant_Type" + std::to_string(elem_type), elem_type); } diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 294de6cb9e..ae3db9d766 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -109,6 +109,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Pad", GetPadGradient); REGISTER_GRADIENT_BUILDER("Identity", GetIdentityGradient); REGISTER_GRADIENT_BUILDER("PythonOp", GetPythonOpGradient); + REGISTER_GRADIENT_BUILDER("ScatterND", GetScatterNDGradient); }; } // namespace training diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 64f1b16b93..2b67ea705e 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -2653,6 +2653,90 @@ TEST(GradientCheckerTest, PadGrad) { } #endif // USE_CUDA +TEST(GradientCheckerTest, ScatterNDGrad) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"ScatterND", kOnnxDomain, 11}; + + { + TensorInfo data_info({8}, true); + TensorInfo indices_info({4, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo updates_info({4}, true); + std::vector> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {4, 3, 1, 7}, {8, 9, 10, 11}}; + + TensorInfo output_info({8}, true); + + gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info}, + {output_info}, &max_error, input_datas); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo data_info({2, 2}, true); + TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo updates_info({2}, true); + std::vector> input_datas = {{0, 1, 2, 3}, {0, 0, 1, 1}, {4, 5}}; + + TensorInfo output_info({2, 2}, true); + + gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info}, + {output_info}, &max_error, input_datas); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo data_info({2, 2}, true); + TensorInfo indices_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo updates_info({2, 2}, true); + std::vector> input_datas = {{0, 1, 2, 3}, {1, 0}, {4, 5, 6, 7}}; + + TensorInfo output_info({2, 2}, true); + + gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info}, + {output_info}, &max_error, input_datas); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo data_info({2, 2, 2}, true); + TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo updates_info({2, 2}, true); + std::vector> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 1, 0}, {8, 9, 10, 11}}; + + TensorInfo output_info({2, 2, 2}, true); + + gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info}, + {output_info}, &max_error, input_datas); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo data_info({2, 2, 2}, true); + TensorInfo indices_info({2, 1, 2}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo updates_info({2, 1, 2}, true); + std::vector> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1, 1, 0}, {8, 9, 10, 11}}; + + TensorInfo output_info({2, 2, 2}, true); + + gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info}, + {output_info}, &max_error, input_datas); + EXPECT_IS_TINY(max_error); + } + + { + TensorInfo data_info({2, 2, 2}, true); + TensorInfo indices_info({2, 1}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo updates_info({2, 2, 2}, true); + std::vector> input_datas = {{0, 1, 2, 3, 4, 5, 6, 7}, {0, 1}, {8, 9, 10, 11, 12, 13, 14, 15}}; + + TensorInfo output_info({2, 2, 2}, true); + + gradient_checker.ComputeGradientError(op_def, {data_info, indices_info, updates_info}, + {output_info}, &max_error, input_datas); + EXPECT_IS_TINY(max_error); + } +} + } // namespace test } // namespace onnxruntime