diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index cde2cb5caf..7301920a54 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -36,7 +36,10 @@ static std::unordered_map> {"Slice", {1, 2, 3, 4}}, {"SparseSoftmaxCrossEntropy", {1, 2}}, {"ConstantOfShape", {0}}, - {"Scatter", {1}}}; + {"Scatter", {1}}, + {"OneHot", {0, 1, 2}}, + {"Where", {0}}, + {"Range", {0, 1, 2}}}; class GradientGraphBuilder { public: diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index dd4245ac93..07f1024736 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -920,6 +920,22 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) { {GI(0)})}; } +IMPLEMENT_GRADIENT_BUILDER(GetWhereGradient) { + std::vector result; + const int64_t data_type = static_cast(I(1).type_proto->tensor_type().elem_type()); + if (IsGradientRequiredForSrcNodeInput(1)) { + result.push_back(NodeDef("Cast", {I(0)}, {IA("Positive_Mask")}, {MakeAttribute("to", data_type)})); + result.push_back(NodeDef("Mul", {GO(0), IA("Positive_Mask")}, {GI(1)})); + } + + if (IsGradientRequiredForSrcNodeInput(2)) { + result.push_back(NodeDef("Not", {I(0)}, {IA("Not_Condition", IType(0))})); + result.push_back(NodeDef("Cast", {IA("Not_Condition")}, {IA("Negative_Mask")}, {MakeAttribute("to", data_type)})); + result.push_back(NodeDef("Mul", {GO(0), IA("Negative_Mask")}, {GI(2)})); + } + return result; +} + IMPLEMENT_GRADIENT_BUILDER(GetSendGradient) { // Send inputs: signal A, remote, data; outputs: signal B // Recv inputs: signal B, remote; outputs: signal A', data' diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index defd029fe0..3de555b3de 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -50,6 +50,7 @@ DECLARE_GRADIENT_BUILDER(GetMegatronFGradient) DECLARE_GRADIENT_BUILDER(GetMegatronGGradient) DECLARE_GRADIENT_BUILDER(GetSliceGradient) DECLARE_GRADIENT_BUILDER(GetFastGeluGradient) +DECLARE_GRADIENT_BUILDER(GetWhereGradient) DECLARE_GRADIENT_BUILDER(GetSendGradient) DECLARE_GRADIENT_BUILDER(GetRecvGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 41189d6d92..057e896ebc 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -11,7 +11,6 @@ namespace training { GradientDef GetGradientForOp(const Node* node, const std::unordered_set& output_args_need_grad, const std::unordered_set& input_args_need_grad) { - // REVIEW(mzs): The below condition does not seem correct, it needs to be >= GRADIENT_OP_VERSION // but changing it will break bunch of tests since many operators like sqrt are version 6, // yet have a grad operator. However changing the opset requires changing the operator @@ -23,7 +22,7 @@ GradientDef GetGradientForOp(const Node* node, // REVIEW(bahuang): We don't have a version control for forward to backward op mapping. // Current SliceGrad(kMSDomain, 1) only supports Slice(kOnnxDomain, 10/11) because adding grad operator for versions // less than 9 is not supported and for Slice we have Slice-1, Slice-10 and Slice-11. - + /*ORT_ENFORCE( node->Op()->SinceVersion() <= GRADIENT_OP_VERSION, "Gradients are supported for opset version" + std::to_string(node->Op()->SinceVersion()) + @@ -89,6 +88,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("MegatronG", GetMegatronGGradient); REGISTER_GRADIENT_BUILDER("Slice", GetSliceGradient); REGISTER_GRADIENT_BUILDER("FastGelu", GetFastGeluGradient); + REGISTER_GRADIENT_BUILDER("Where", GetWhereGradient); REGISTER_GRADIENT_BUILDER("Send", GetSendGradient); REGISTER_GRADIENT_BUILDER("Recv", GetRecvGradient); }; diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 660ede3049..05706779bd 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -68,7 +68,6 @@ inline std::vector GradientChecker::EvaluateFunctionA const std::vector& y_infos, std::vector>* x_datas, std::vector>* y_datas) { - // clear OpTester input/output/initializer_index op_session.ClearData(); @@ -84,6 +83,12 @@ inline std::vector GradientChecker::EvaluateFunctionA std::vector int32_data(data.size()); std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast(x); }); op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data); + } else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType()) { + std::unique_ptr p_data(new bool[data.size()]); + for (size_t i = 0; i < data.size(); ++i) { + p_data[i] = static_cast(data[i]); + } + op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size()); } else { op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), data); } @@ -140,6 +145,12 @@ inline Status GradientChecker::ComputeTheoreticalJacobianTransp std::vector int32_data(data.size()); std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast(x); }); op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data); + } else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType()) { + std::unique_ptr p_data(new bool[data.size()]); + for (size_t i = 0; i < data.size(); ++i) { + p_data[i] = static_cast(data[i]); + } + op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size()); } else { op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), data); } @@ -193,7 +204,6 @@ inline Status GradientChecker::InitOpTesterWithGraph( std::vector>* y_datas, const std::vector& attributes, const std::unordered_map& extra_domain_to_version) { - for (size_t data_index = 0; data_index < x_datas->size(); data_index++) { std::string name = "input" + std::to_string(data_index); const std::vector& data = (*x_datas)[data_index]; @@ -206,6 +216,12 @@ inline Status GradientChecker::InitOpTesterWithGraph( std::vector int32_data(data.size()); std::transform(data.begin(), data.end(), int32_data.begin(), [](X_T x) { return static_cast(x); }); op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), int32_data); + } else if (x_infos[data_index].data_type == DataTypeImpl::GetTensorType()) { + std::unique_ptr p_data(new bool[data.size()]); + for (size_t i = 0; i < data.size(); ++i) { + p_data[i] = static_cast(data[i]); + } + op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), p_data.get(), data.size()); } else { op_session.AddInput(name.c_str(), x_infos[data_index].shape.GetDims(), data); } @@ -228,8 +244,8 @@ inline Status GradientChecker::InitOpTesterWithGraph( status = graph.Resolve(); if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Resolve failed with status: " << status.ErrorMessage(); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + LOGS_DEFAULT(ERROR) << "Resolve failed with status: " << status.ErrorMessage(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); } if (!status.IsOK()) { @@ -251,7 +267,6 @@ inline Status GradientChecker::InitOpTesterWithGradGraph( std::vector>* x_datas, std::vector>* y_datas, const std::vector& attributes) { - std::unordered_map extra_domain_to_version{{kMSDomain, 1}, {kOnnxDomain, 9}}; InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes, extra_domain_to_version); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 2e1cca4c8c..0dd17510e2 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1446,6 +1446,23 @@ TEST(GradientUtilsTest, ZeroGradientFloat16) { } #endif +TEST(GradientCheckerTest, WhereGrad) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Where"}; + + std::vector shape{4, 3, 2}; + TensorInfo x_info(shape), y_info(shape); + std::function transformer = [](float x) { + return static_cast(std::fmod(std::fabs(x), 1.0f) > 0.5f); + }; + TensorInfo condition_info(shape, false, &transformer, DataTypeImpl::GetTensorType()); + + TensorShape output_shape{shape}; + gradient_checker.ComputeGradientError(op_def, {condition_info, x_info, y_info}, {output_shape}, &max_error); + EXPECT_IS_TINY(max_error); +} + TEST(GradientCheckerTest, SliceGrad) { float max_error; GradientChecker gradient_checker;