diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 4723246fcf..5fee182fb1 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -60,7 +60,6 @@ namespace cuda { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); - // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template template diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index e1bac6ddab..47287269eb 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -598,7 +598,7 @@ TEST(ReductionOpTest, ReduceMax_int32) { #if defined(OPENVINO_CONFIG_GPU_FP32) test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Disabled temporarily #else - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: axis must be 0 + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: axis must be 0 #endif } diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h index 6a8c7830e9..8d1d8282d9 100644 --- a/onnxruntime/test/providers/provider_test_utils.h +++ b/onnxruntime/test/providers/provider_test_utils.h @@ -404,7 +404,7 @@ class OpTester { template void AddAttribute(std::string name, T value) { // Generate a the proper AddAttribute call for later - add_attribute_funcs_.emplace_back([name = std::move(name), value = std::move(value)](onnxruntime::Node& node) { + add_attribute_funcs_.emplace_back([ name = std::move(name), value = std::move(value) ](onnxruntime::Node & node) { node.AddAttribute(name, value); }); } @@ -541,7 +541,7 @@ class OpTester { value.Init(p_tensor.release(), DataTypeImpl::GetType(), DataTypeImpl::GetType()->GetDeleteFunc()); auto node_arg = NodeArg(name, &type_proto.proto); - if (dim_params && !(dim_params->empty())) { + if (dim_params && !(dim_params->empty()) && add_shape_to_tensor_data_) { // If dim_params presents, configure node_arg's dim value based on dim_params, which supports symbolic dim and dim broadcast. auto& dim_params_data = *dim_params; onnx::TensorShapeProto new_shape; diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index a823d46ba9..eb6ea80457 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -62,192 +62,221 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { std::vector result; ArgDef A = I(0), B = I(1), Y = O(0); - std::vector A_shape = GetShape(A); - std::vector B_shape = GetShape(B); - std::vector Y_shape = GetShape(Y); + std::vector A_shape, B_shape, Y_shape; + if (GetShape(A, A_shape).IsOK() && GetShape(B, B_shape).IsOK() && GetShape(Y, Y_shape).IsOK()) { + std::vector shared_attributes; + shared_attributes.push_back(MakeAttribute("beta", float(0))); + AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1)); + AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1)); - std::vector shared_attributes; - shared_attributes.push_back(MakeAttribute("beta", float(0))); - AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1)); - AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1)); + if (A_shape.size() == 2 && B_shape.size() == 2) { + NodeDef zero_constant_node = ZeroConstantNode(); + ArgDef ZERO = zero_constant_node.output_args[0]; + result.push_back(zero_constant_node); - if (A_shape.size() == 2 && B_shape.size() == 2) { - NodeDef zero_constant_node = ZeroConstantNode(); - ArgDef ZERO = zero_constant_node.output_args[0]; - result.push_back(zero_constant_node); - - // is GI(0) required - if (IsGradientRequiredForSrcNodeInput(0)) { - // dA = dY * B' - std::vector attrs(shared_attributes); - attrs.push_back(transpose_second_input); - result.push_back( - NodeDef("Gemm", - {GO(0), B, ZERO}, - {GI(0)}, - attrs)); - } - - // is GI(1) required - if (IsGradientRequiredForSrcNodeInput(1)) { - // dB = A' * dY - std::vector attrs(shared_attributes); - attrs.push_back(transpose_first_input); - result.push_back( - NodeDef("Gemm", - {A, GO(0), ZERO}, - {GI(1)}, - attrs)); - } - } else if (A_shape.size() > 2 || B_shape.size() > 2) { - if (IsGradientRequiredForSrcNodeInput(0)) { - // If B_shape.size() == 2, dA is computed through 2 ops: transpose and matmul. - // It can be replaced with Gemm(dY_reshape, B_transpose) and reshape. - // However, there is a performance degradation. - // Thus this implementation is not implemented. - int64_t B_rank = B_shape.size(); - std::vector B_perm(B_rank); - std::iota(B_perm.begin(), B_perm.end(), 0); - std::swap(B_perm[B_rank - 1], B_perm[B_rank - 2]); - - std::vector output_shape; - for (size_t i = 0; i < Y_shape.size() - 1; i++) { - output_shape.push_back(Y_shape[i]); + // is GI(0) required + if (IsGradientRequiredForSrcNodeInput(0)) { + // dA = dY * B' + std::vector attrs(shared_attributes); + attrs.push_back(transpose_second_input); + result.push_back( + NodeDef("Gemm", + {GO(0), B, ZERO}, + {GI(0)}, + attrs)); } - output_shape.push_back(B_shape[B_shape.size() - 2]); - - std::vector A_axes; - ComputeBroadcastBackwardAxes(A_shape, output_shape, &A_axes, nullptr); - - result.push_back( - NodeDef("Transpose", - {B}, - {IA("B_t")}, - {MakeAttribute("perm", B_perm)})); - - ArgDef matmul_out = A_axes.size() > 0 ? IA("PreReduceGrad0") : GI(0); - - result.push_back( - NodeDef("MatMul", - {GO(0), IA("B_t")}, - {matmul_out})); - - if (A_axes.size() > 0) { - result.push_back( - NodeDef("ReduceSum", - {IA("PreReduceGrad0")}, - {IA("ReduceGrad0")}, - {{"keepdims", MakeAttribute("keepdims", int64_t(1))}, - {"axes", MakeAttribute("axes", A_axes)}})); - - result.push_back( - NodeDef("Shape", - {A}, - {IA("A_shape")})); - - result.push_back( - NodeDef("Reshape", - {IA("ReduceGrad0"), IA("A_shape")}, - {GI(0)})); - } - } - if (IsGradientRequiredForSrcNodeInput(1)) { - if (B_shape.size() == 2 && - (B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) && - (B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) { - // A[M, K], B[K, N], Y[M, N] - int64_t K, N; - if (B_shape[0].has_dim_value()) { - K = B_shape[0].dim_value(); - } else { - K = A_shape[A_shape.size() - 1].dim_value(); - } - if (B_shape[1].has_dim_value()) { - N = B_shape[1].dim_value(); - } else { - N = Y_shape[Y_shape.size() - 1].dim_value(); - } - - std::vector A_shape_2d{-1, K}; - NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d")); - ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0]; - result.push_back(A_shape_2d_node); - - std::vector dY_shape_2d{-1, N}; - NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d")); - ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0]; - result.push_back(dY_shape_2d_node); - - NodeDef zero_constant_node = ZeroConstantNode(); - ArgDef ZERO = zero_constant_node.output_args[0]; - result.push_back(zero_constant_node); - - result.push_back( - NodeDef("Reshape", - {A, A_shape_2d_arg}, - {IA("A_reshape_2d")})); - result.push_back( - NodeDef("Reshape", - {GO(0), dY_shape_2d_arg}, - {IA("dY_reshape_2d")})); + // is GI(1) required + if (IsGradientRequiredForSrcNodeInput(1)) { // dB = A' * dY std::vector attrs(shared_attributes); attrs.push_back(transpose_first_input); result.push_back( NodeDef("Gemm", - {IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO}, + {A, GO(0), ZERO}, {GI(1)}, attrs)); - } else { - int64_t A_rank = A_shape.size(); - std::vector A_perm(A_rank); - std::iota(A_perm.begin(), A_perm.end(), 0); - std::swap(A_perm[A_rank - 1], A_perm[A_rank - 2]); + } + } else if (A_shape.size() > 2 || B_shape.size() > 2) { + if (IsGradientRequiredForSrcNodeInput(0)) { + // If B_shape.size() == 2, dA is computed through 2 ops: transpose and matmul. + // It can be replaced with Gemm(dY_reshape, B_transpose) and reshape. + // However, there is a performance degradation. + // Thus this implementation is not implemented. + int64_t B_rank = B_shape.size(); + std::vector B_perm(B_rank); + std::iota(B_perm.begin(), B_perm.end(), 0); + std::swap(B_perm[B_rank - 1], B_perm[B_rank - 2]); std::vector output_shape; - for (size_t i = 0; i < Y_shape.size() - 2; i++) { + for (size_t i = 0; i < Y_shape.size() - 1; i++) { output_shape.push_back(Y_shape[i]); } - output_shape.push_back(A_shape[A_shape.size() - 1]); - output_shape.push_back(Y_shape[Y_shape.size() - 1]); + output_shape.push_back(B_shape[B_shape.size() - 2]); - std::vector B_axes; - ComputeBroadcastBackwardAxes(B_shape, output_shape, &B_axes, nullptr); + std::vector A_axes; + ComputeBroadcastBackwardAxes(A_shape, output_shape, &A_axes, nullptr); result.push_back( NodeDef("Transpose", - {A}, - {IA("A_t")}, - {MakeAttribute("perm", A_perm)})); + {B}, + {IA("B_t")}, + {MakeAttribute("perm", B_perm)})); - ArgDef matmul_out = B_axes.size() > 0 ? IA("PreReduceGrad1") : GI(1); + ArgDef matmul_out = A_axes.size() > 0 ? IA("PreReduceGrad0") : GI(0); result.push_back( NodeDef("MatMul", - {IA("A_t"), GO(0)}, + {GO(0), IA("B_t")}, {matmul_out})); - if (B_axes.size() > 0) { + if (A_axes.size() > 0) { result.push_back( NodeDef("ReduceSum", - {IA("PreReduceGrad1")}, - {IA("ReduceGrad1")}, - {{"keepdims", MakeAttribute("keepdims", int64_t(0))}, - {"axes", MakeAttribute("axes", B_axes)}})); + {IA("PreReduceGrad0")}, + {IA("ReduceGrad0")}, + {{"keepdims", MakeAttribute("keepdims", int64_t(1))}, + {"axes", MakeAttribute("axes", A_axes)}})); + result.push_back( NodeDef("Shape", - {B}, - {IA("B_shape")})); + {A}, + {IA("A_shape")})); + result.push_back( NodeDef("Reshape", - {IA("ReduceGrad1"), IA("B_shape")}, - {GI(1)})); + {IA("ReduceGrad0"), IA("A_shape")}, + {GI(0)})); + } + } + if (IsGradientRequiredForSrcNodeInput(1)) { + if (B_shape.size() == 2 && + (B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) && + (B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) { + // A[M, K], B[K, N], Y[M, N] + int64_t K, N; + if (B_shape[0].has_dim_value()) { + K = B_shape[0].dim_value(); + } else { + K = A_shape[A_shape.size() - 1].dim_value(); + } + if (B_shape[1].has_dim_value()) { + N = B_shape[1].dim_value(); + } else { + N = Y_shape[Y_shape.size() - 1].dim_value(); + } + + std::vector A_shape_2d{-1, K}; + NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d")); + ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0]; + result.push_back(A_shape_2d_node); + + std::vector dY_shape_2d{-1, N}; + NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d")); + ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0]; + result.push_back(dY_shape_2d_node); + + NodeDef zero_constant_node = ZeroConstantNode(); + ArgDef ZERO = zero_constant_node.output_args[0]; + result.push_back(zero_constant_node); + + result.push_back( + NodeDef("Reshape", + {A, A_shape_2d_arg}, + {IA("A_reshape_2d")})); + result.push_back( + NodeDef("Reshape", + {GO(0), dY_shape_2d_arg}, + {IA("dY_reshape_2d")})); + + // dB = A' * dY + std::vector attrs(shared_attributes); + attrs.push_back(transpose_first_input); + result.push_back( + NodeDef("Gemm", + {IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO}, + {GI(1)}, + attrs)); + } else { + int64_t A_rank = A_shape.size(); + std::vector A_perm(A_rank); + std::iota(A_perm.begin(), A_perm.end(), 0); + std::swap(A_perm[A_rank - 1], A_perm[A_rank - 2]); + + std::vector output_shape; + for (size_t i = 0; i < Y_shape.size() - 2; i++) { + output_shape.push_back(Y_shape[i]); + } + output_shape.push_back(A_shape[A_shape.size() - 1]); + output_shape.push_back(Y_shape[Y_shape.size() - 1]); + + std::vector B_axes; + ComputeBroadcastBackwardAxes(B_shape, output_shape, &B_axes, nullptr); + + result.push_back( + NodeDef("Transpose", + {A}, + {IA("A_t")}, + {MakeAttribute("perm", A_perm)})); + + ArgDef matmul_out = B_axes.size() > 0 ? IA("PreReduceGrad1") : GI(1); + + result.push_back( + NodeDef("MatMul", + {IA("A_t"), GO(0)}, + {matmul_out})); + + if (B_axes.size() > 0) { + result.push_back( + NodeDef("ReduceSum", + {IA("PreReduceGrad1")}, + {IA("ReduceGrad1")}, + {{"keepdims", MakeAttribute("keepdims", int64_t(0))}, + {"axes", MakeAttribute("axes", B_axes)}})); + result.push_back( + NodeDef("Shape", + {B}, + {IA("B_shape")})); + result.push_back( + NodeDef("Reshape", + {IA("ReduceGrad1"), IA("B_shape")}, + {GI(1)})); + } } } } } else { - ORT_THROW("Matmul Gradient Builder shouldn't reach here. "); + //GetShape failed, build shape-independent gradient graph + ArgDef a_axes, b_axes, a_shape, b_shape, ia_shape; + a_shape = IA("Shape_" + A.name); + b_shape = IA("Shape_" + B.name); + + if (IsGradientRequiredForSrcNodeInput(0)) { + ArgDef pre_reduce_grad_0 = IA("PreReduceGrad0"); + result.push_back( + NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1}, + {GO(0), B}, + {pre_reduce_grad_0}, + {{"transB", MakeAttribute("transB", int64_t(1))}})); + + a_axes = IA("ReduceAxes_" + A.name + "_for_" + A.name); + ia_shape = IA("Shape_" + pre_reduce_grad_0.name); + ComputeBroadcastBackwardAxesDynamic(A, pre_reduce_grad_0, a_shape, ia_shape, &a_axes, nullptr, result); + HandleBroadcastingDynamic(pre_reduce_grad_0, A, a_shape, GI(0), a_axes, result); + } + if (IsGradientRequiredForSrcNodeInput(1)) { + ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1"); + result.push_back( + NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1}, + {A, GO(0)}, + {pre_reduce_grad_1}, + {{"transA", MakeAttribute("transA", int64_t(1))}})); + + b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name); + ia_shape = IA("Shape_" + pre_reduce_grad_1.name); + ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result); + HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result); + } } return result; @@ -346,15 +375,51 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { bool has_beta = attributes.at("beta").has_f(); float beta = attributes.at("beta").f(); ORT_ENFORCE(beta != 0.0f); + std::vector C_shape, dY_shape; + if (GetShape(C, C_shape).IsOK() && GetShape(dY, dY_shape).IsOK()) { + std::vector C_axes, dY_axes; + ComputeBroadcastBackwardAxes(C_shape, dY_shape, &C_axes, &dY_axes); - std::vector C_shape = GetShape(C); - std::vector dY_shape = GetShape(dY); + if (C_axes.size() > 0) { + HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result); - std::vector C_axes, dY_axes; - ComputeBroadcastBackwardAxes(C_shape, dY_shape, &C_axes, &dY_axes); + if (has_beta && beta != 1.0f) { + NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); + ArgDef SCALE = scale_node.output_args[0]; + result.push_back(scale_node); + result.push_back( + NodeDef("Mul", + {IA("dC_reduced"), SCALE}, + {dC})); + } else { + result.push_back( + NodeDef("Identity", {IA("dC_reduced")}, {dC})); + } + } else { + if (has_beta && beta != 1.0f) { + NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); + ArgDef SCALE = scale_node.output_args[0]; + result.push_back(scale_node); + result.push_back( + NodeDef("Mul", + {dY, SCALE}, + {dC})); + } else { + result.push_back( + NodeDef("Identity", + {dY}, + {dC})); + } + } + } else { + //GetShape failed, build shape-independent gradient graph + ArgDef c_axes = IA("ReduceAxes_" + C.name); + ArgDef c_shape = IA("Shape_" + C.name); + ArgDef dy_shape = IA("Shape_" + dY.name); - if (C_axes.size() > 0) { - HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result); + ComputeBroadcastBackwardAxesDynamic(C, dY, c_shape, dy_shape, &c_axes, nullptr, result); + + HandleBroadcastingDynamic(dY, C, c_shape, IA("dC_reduced"), c_axes, result); if (has_beta && beta != 1.0f) { NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); @@ -368,21 +433,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { result.push_back( NodeDef("Identity", {IA("dC_reduced")}, {dC})); } - } else { - if (has_beta && beta != 1.0f) { - NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); - ArgDef SCALE = scale_node.output_args[0]; - result.push_back(scale_node); - result.push_back( - NodeDef("Mul", - {dY, SCALE}, - {dC})); - } else { - result.push_back( - NodeDef("Identity", - {dY}, - {dC})); - } } } return result; @@ -419,7 +469,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) { std::vector split_attribute(GetSrcNodeInputSize()); std::vector outputs; for (int i = 0; i < GetSrcNodeInputSize(); ++i) { - std::vector data_shape = GetShape(I(i)); + std::vector data_shape; + ORT_ENFORCE(GetShape(I(i), data_shape).IsOK()); int64_t axis_index = axis < 0 ? static_cast(data_shape.size()) + axis : axis; if (axis_index >= 0 && axis_index < static_cast(data_shape.size()) && data_shape[axis_index].has_dim_value()) { split_attribute[i] = data_shape[axis_index].dim_value(); @@ -464,10 +515,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) { IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) { std::vector bw_perm; auto attributes = SrcNodeAttributes(); + std::vector new_attributes; if (attributes.empty()) { const TensorShapeProto& input_shape = I(0).type_proto->tensor_type().shape(); - for (int i = input_shape.dim_size() - 1; i >= 0; --i) { - bw_perm.push_back(i); + if (input_shape.dim_size() > 0) { //input_shape is available + int n = input_shape.dim_size() - 1; + bw_perm.resize(n + 1); + std::generate(bw_perm.begin(), bw_perm.end(), [&n] { return n--; }); + new_attributes.push_back(MakeAttribute("perm", bw_perm)); } } else { auto fw_perm = RetrieveValues(attributes.at("perm")); @@ -476,13 +531,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) { for (int i = 0; i < static_cast(size); ++i) { bw_perm[fw_perm[i]] = i; } + new_attributes.push_back(MakeAttribute("perm", bw_perm)); } return std::vector{ NodeDef("Transpose", {GO(0)}, {GI(0)}, - {MakeAttribute("perm", bw_perm)})}; + new_attributes)}; } IMPLEMENT_GRADIENT_BUILDER(GetAveragePoolGradient) { @@ -624,30 +680,62 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) { bool is_sub = (SrcNodeOpType() == "Sub"); const ArgDef a = I(0), b = I(1); - - std::vector a_shape = GetShape(a); - std::vector b_shape = GetShape(b); - - std::vector a_axes, b_axes; - ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes); - std::vector output; - - if (IsGradientRequiredForSrcNodeInput(0)) { - if (a_axes.size() > 0) { - HandleBroadcasting(GO(0), a, GI(0), a_axes, output); - } else { - output.push_back( - NodeDef("Identity", - {GO(0)}, - {GI(0)})); + std::vector a_shape, b_shape; + if (GetShape(a, a_shape).IsOK() && GetShape(b, b_shape).IsOK()) { + std::vector a_axes, b_axes; + ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes); + if (IsGradientRequiredForSrcNodeInput(0)) { + if (a_axes.size() > 0) { + HandleBroadcasting(GO(0), a, GI(0), a_axes, output); + } else { + output.push_back( + NodeDef("Identity", + {GO(0)}, + {GI(0)})); + } } - } - if (IsGradientRequiredForSrcNodeInput(1)) { - if (b_axes.size() > 0) { + if (IsGradientRequiredForSrcNodeInput(1)) { + if (b_axes.size() > 0) { + ArgDef reshape_output = is_sub ? IA("ReshapeReduceSum_2", IType(1)) : GI(1); + HandleBroadcasting(GO(0), b, reshape_output, b_axes, output); + + if (is_sub) { + output.push_back( + NodeDef("Neg", + {reshape_output}, + {GI(1)})); + } + } else { + if (is_sub) { + output.push_back( + NodeDef("Neg", + {GO(0)}, + {GI(1)})); + } else /*is_add*/ { + output.push_back( + NodeDef("Identity", + {GO(0)}, + {GI(1)})); + } + } + } + } else { + //GetShape failed, build shape-independent gradient graph + ArgDef a_axes = IA("ReduceAxes_" + a.name); + ArgDef b_axes = IA("ReduceAxes_" + b.name); + ArgDef A_shape = IA("Shape_" + a.name); + ArgDef B_shape = IA("Shape_" + b.name); + ComputeBroadcastBackwardAxesDynamic(a, b, A_shape, B_shape, &a_axes, &b_axes, output); + + if (IsGradientRequiredForSrcNodeInput(0)) { + HandleBroadcastingDynamic(GO(0), a, A_shape, GI(0), a_axes, output); + } + + if (IsGradientRequiredForSrcNodeInput(1)) { ArgDef reshape_output = is_sub ? IA("ReshapeReduceSum_2", IType(1)) : GI(1); - HandleBroadcasting(GO(0), b, reshape_output, b_axes, output); + HandleBroadcastingDynamic(GO(0), b, B_shape, reshape_output, b_axes, output); if (is_sub) { output.push_back( @@ -655,18 +743,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) { {reshape_output}, {GI(1)})); } - } else { - if (is_sub) { - output.push_back( - NodeDef("Neg", - {GO(0)}, - {GI(1)})); - } else /*is_add*/ { - output.push_back( - NodeDef("Identity", - {GO(0)}, - {GI(1)})); - } } } return output; @@ -675,44 +751,70 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) { IMPLEMENT_GRADIENT_BUILDER(GetMulGradient) { const ArgDef a = I(0), b = I(1); - std::vector a_shape = GetShape(a); - std::vector b_shape = GetShape(b); - std::vector a_axes, b_axes; - ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes); - std::vector output; + std::vector a_shape, b_shape; + if (GetShape(a, a_shape).IsOK() && GetShape(b, b_shape).IsOK()) { + std::vector a_axes, b_axes; + ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes); - if (IsGradientRequiredForSrcNodeInput(0)) { - output.push_back( - NodeDef("Mul", - {GO(0), I(1)}, - {IA("PreReduceGrad0", OType(0))})); - - if (a_axes.size() > 0) { - HandleBroadcasting(IA("PreReduceGrad0", OType(0)), a, GI(0), a_axes, output); - } else { + if (IsGradientRequiredForSrcNodeInput(0)) { output.push_back( - NodeDef("Identity", - {IA("PreReduceGrad0", OType(0))}, - {GI(0)})); + NodeDef("Mul", + {GO(0), I(1)}, + {IA("PreReduceGrad0", OType(0))})); + + if (a_axes.size() > 0) { + HandleBroadcasting(IA("PreReduceGrad0", OType(0)), a, GI(0), a_axes, output); + } else { + output.push_back( + NodeDef("Identity", + {IA("PreReduceGrad0", OType(0))}, + {GI(0)})); + } + } + + if (IsGradientRequiredForSrcNodeInput(1)) { + output.push_back( + NodeDef("Mul", + {GO(0), I(0)}, + {IA("PreReduceGrad1", OType(0))})); + + if (b_axes.size() > 0) { + HandleBroadcasting(IA("PreReduceGrad1", OType(0)), b, GI(1), b_axes, output); + } else { + output.push_back( + NodeDef("Identity", + {IA("PreReduceGrad1", OType(0))}, + {GI(1)})); + } + } + } else { + //GetShape failed, build shape-independent gradient graph + ArgDef a_axes = IA("ReduceAxes_" + a.name); + ArgDef b_axes = IA("ReduceAxes_" + b.name); + ArgDef A_shape = IA("Shape_" + a.name); + ArgDef B_shape = IA("Shape_" + b.name); + ComputeBroadcastBackwardAxesDynamic(a, b, A_shape, B_shape, &a_axes, &b_axes, output); + + if (IsGradientRequiredForSrcNodeInput(0)) { + output.push_back( + NodeDef("Mul", + {GO(0), I(1)}, + {IA("PreReduceGrad0", OType(0))})); + + HandleBroadcastingDynamic(IA("PreReduceGrad0", OType(0)), a, A_shape, GI(0), a_axes, output); + } + + if (IsGradientRequiredForSrcNodeInput(1)) { + output.push_back( + NodeDef("Mul", + {GO(0), I(0)}, + {IA("PreReduceGrad1", OType(0))})); + + HandleBroadcastingDynamic(IA("PreReduceGrad1", OType(0)), b, B_shape, GI(1), b_axes, output); } } - if (IsGradientRequiredForSrcNodeInput(1)) { - output.push_back( - NodeDef("Mul", - {GO(0), I(0)}, - {IA("PreReduceGrad1", OType(0))})); - - if (b_axes.size() > 0) { - HandleBroadcasting(IA("PreReduceGrad1", OType(0)), b, GI(1), b_axes, output); - } else { - output.push_back( - NodeDef("Identity", - {IA("PreReduceGrad1", OType(0))}, - {GI(1)})); - } - } return output; } @@ -725,17 +827,32 @@ IMPLEMENT_GRADIENT_BUILDER(GetDivGradient) { } else if (IsGradientRequiredForSrcNodeInput(0)) { // Y = A / B, dA = dY / B const ArgDef a = I(0), b = I(1); - std::vector a_axes, b_axes; - ComputeBroadcastBackwardAxes(GetShape(a), GetShape(b), &a_axes, &b_axes); - std::vector output; - ArgDef tmp_grad = IA("PreReduceGrad0", OType(0)); - output.push_back(NodeDef("Div", {GO(0), I(1)}, {tmp_grad})); - if (a_axes.size() > 0) { - HandleBroadcasting(tmp_grad, a, GI(0), a_axes, output); + std::vector a_shape, b_shape; + if (GetShape(a, a_shape).IsOK() && GetShape(b, b_shape).IsOK()) { + std::vector a_axes, b_axes; + ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes); + + ArgDef tmp_grad = IA("PreReduceGrad0", OType(0)); + output.push_back(NodeDef("Div", {GO(0), I(1)}, {tmp_grad})); + if (a_axes.size() > 0) { + HandleBroadcasting(tmp_grad, a, GI(0), a_axes, output); + } else { + output.push_back(NodeDef("Identity", {tmp_grad}, {GI(0)})); + } } else { - output.push_back(NodeDef("Identity", {tmp_grad}, {GI(0)})); + //GetShape failed, build shape-independent gradient graph + ArgDef a_axes = IA("ReduceAxes_" + a.name); + ArgDef A_shape = IA("Shape_" + a.name); + ArgDef B_shape = IA("Shape_" + b.name); + + ComputeBroadcastBackwardAxesDynamic(a, b, A_shape, B_shape, &a_axes, nullptr, output); + + ArgDef tmp_grad = IA("PreReduceGrad0", OType(0)); + output.push_back(NodeDef("Div", {GO(0), I(1)}, {tmp_grad})); + HandleBroadcastingDynamic(tmp_grad, a, A_shape, GI(0), a_axes, output); } + return output; } else if (IsGradientRequiredForSrcNodeInput(1)) { return std::vector{ @@ -907,33 +1024,52 @@ IMPLEMENT_GRADIENT_BUILDER(GetGeluGradient) { namespace { std::vector GetBiasGeluGradNodes( bool use_approximation, - const ArgDef& dY, const ArgDef& X, const ArgDef& B, // inputs - const ArgDef& dX, const ArgDef& dB) { // outputs - const auto B_shape = GetShape(B); - ORT_ENFORCE(B_shape.size() == 1, "B must have exactly one dimension."); + const ArgDef& dY, const ArgDef& X, const ArgDef& B, // inputs + const ArgDef& dX, const ArgDef& dB, // outputs + const ArgDef& b_axes, const ArgDef& b_shape, const ArgDef& x_shape) { //intermediate args + std::vector B_shape, X_shape; + if (GetShape(B, B_shape).IsOK() && GetShape(X, X_shape).IsOK()) { + ORT_ENFORCE(B_shape.size() == 1, "B must have exactly one dimension."); - const std::vector B_axes = [&B_shape, &X]() { - std::vector result{}; - ComputeBroadcastBackwardAxes(B_shape, GetShape(X), &result, nullptr); + const std::vector B_axes = [&B_shape, &X_shape]() { + std::vector result{}; + ComputeBroadcastBackwardAxes(B_shape, X_shape, &result, nullptr); + return result; + }(); + return std::vector{ + NodeDef(OpDef{use_approximation ? "BiasFastGeluGrad_dX" : "BiasGeluGrad_dX", kMSDomain, 1}, + {dY, X, B}, + {dX}), + NodeDef("ReduceSum", + {dX}, + {dB}, + {{"keepdims", MakeAttribute("keepdims", int64_t{0})}, + {"axes", MakeAttribute("axes", B_axes)}})}; + } else { + std::vector result; + ComputeBroadcastBackwardAxesDynamic(B, X, b_shape, x_shape, &b_axes, nullptr, result); + result.push_back( + NodeDef(OpDef{use_approximation ? "BiasFastGeluGrad_dX" : "BiasGeluGrad_dX", kMSDomain, 1}, + {dY, X, B}, + {dX})); + result.push_back( + NodeDef(OpDef{"ReduceSumTraining", kMSDomain, 1}, + {dX, + b_axes}, + {dB}, + {{"keepdims", MakeAttribute("keepdims", int64_t{0})}})); return result; - }(); - - return std::vector{ - NodeDef(OpDef{use_approximation ? "BiasFastGeluGrad_dX" : "BiasGeluGrad_dX", kMSDomain, 1}, - {dY, X, B}, - {dX}), - NodeDef("ReduceSum", - {dX}, - {dB}, - {{"keepdims", MakeAttribute("keepdims", int64_t{0})}, - {"axes", MakeAttribute("axes", B_axes)}})}; + } } } // namespace IMPLEMENT_GRADIENT_BUILDER(GetBiasGeluGradient) { const auto dY = GO(0), X = I(0), B = I(1), dX = GI(0), dB = GI(1); - return GetBiasGeluGradNodes(false, dY, X, B, dX, dB); + ArgDef b_axes = IA("ReduceAxes_" + B.name); + ArgDef b_shape = IA("Shape_" + B.name); + ArgDef x_shape = IA("Shape_" + X.name); + return GetBiasGeluGradNodes(false, dY, X, B, dX, dB, b_axes, b_shape, x_shape); } IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) { @@ -944,7 +1080,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) { // FastGeluGrad doesn't support bias - it needs to be composed with other ops const auto B = I(1), dB = GI(1); - return GetBiasGeluGradNodes(true, dY, X, B, dX, dB); + ArgDef b_axes = IA("ReduceAxes_" + B.name); + ArgDef b_shape = IA("Shape_" + B.name); + ArgDef x_shape = IA("Shape_" + X.name); + return GetBiasGeluGradNodes(true, dY, X, B, dX, dB, b_axes, b_shape, x_shape); } if (num_src_node_inputs == 1) { // without bias return std::vector{ @@ -1070,19 +1209,30 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) { IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) { ArgDef a = I(0), y = O(0); - std::vector a_shape = GetShape(a); - std::vector y_shape = GetShape(y); - std::vector a_axes; - ComputeBroadcastBackwardAxes(a_shape, y_shape, &a_axes, nullptr); - std::vector output; - if (a_axes.size() > 0) { - HandleBroadcasting(GO(0), a, GI(0), a_axes, output); + + std::vector a_shape, y_shape; + if (GetShape(a, a_shape).IsOK() && GetShape(y, y_shape).IsOK()) { + std::vector a_axes; + ComputeBroadcastBackwardAxes(a_shape, y_shape, &a_axes, nullptr); + + if (a_axes.size() > 0) { + HandleBroadcasting(GO(0), a, GI(0), a_axes, output); + } else { + output.push_back( + NodeDef("Identity", + {GO(0)}, + {GI(0)})); + } } else { - output.push_back( - NodeDef("Identity", - {GO(0)}, - {GI(0)})); + //GetShape failed, build shape-independent gradient graph + ArgDef a_axes = IA("ReduceAxes_" + a.name); + ArgDef A_shape = IA("Shape_" + a.name); + ArgDef Y_shape = IA("Shape_" + y.name); + + ComputeBroadcastBackwardAxesDynamic(a, y, A_shape, Y_shape, &a_axes, nullptr, output); + + HandleBroadcastingDynamic(GO(0), a, A_shape, GI(0), a_axes, output); } return output; diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.cc b/orttraining/orttraining/core/graph/gradient_builder_base.cc index 7ec4b1256e..293074c562 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_base.cc @@ -85,17 +85,43 @@ void ComputeBroadcastBackwardAxes( } } -std::vector GetShape(const ArgDef& arg_def) { - ORT_ENFORCE(arg_def.type_proto - && arg_def.type_proto->has_tensor_type() - && arg_def.type_proto->tensor_type().has_shape(), - "During GetShape, ", arg_def.name, "'s shape is null."); - std::vector shape; +Status GetShape(const ArgDef& arg_def, std::vector& shape) { + shape.clear(); + ORT_RETURN_IF_NOT(arg_def.type_proto && arg_def.type_proto->has_tensor_type() && arg_def.type_proto->tensor_type().has_shape(), + "During GetShape, ", arg_def.name, "'s shape is null."); const auto& dims = arg_def.type_proto->tensor_type().shape().dim(); for (auto dim = dims.begin(); dim < dims.end(); dim++) { shape.push_back(*dim); } - return shape; + return Status::OK(); +} + +void ComputeBroadcastBackwardAxesDynamic(const ArgDef& a, + const ArgDef& b, + const ArgDef& a_shape, + const ArgDef& b_shape, + const ArgDef* a_axes, + const ArgDef* b_axes, + std::vector& output) { + output.push_back( + NodeDef("Shape", + {a}, + {a_shape})); + + output.push_back( + NodeDef("Shape", + {b}, + {b_shape})); + + ArgDef a_op = ArgDef(""), b_op = ArgDef(""); + if (a_axes) + a_op = *a_axes; + if (b_axes) + b_op = *b_axes; + output.push_back( + NodeDef(OpDef{"BroadcastGradientArgs", kMSDomain, 1}, + {a_shape, b_shape}, + {a_op, b_op})); } void GradientBuilderBase::HandleBroadcasting(const ArgDef& input_grad, @@ -104,9 +130,9 @@ void GradientBuilderBase::HandleBroadcasting(const ArgDef& input_grad, const std::vector& reduce_axes, std::vector& output) const { std::unordered_set reduce_axes_set(reduce_axes.begin(), reduce_axes.end()); - std::vector reduced_shape; - auto input_grad_shape = GetShape(input_grad); - auto target_shape = GetShape(target); + std::vector reduced_shape, input_grad_shape, target_shape; + ORT_ENFORCE(GetShape(input_grad, input_grad_shape).IsOK()); + ORT_ENFORCE(GetShape(target, target_shape).IsOK()); bool keep_dims = (input_grad_shape.size() == target_shape.size()); @@ -165,5 +191,26 @@ void GradientBuilderBase::HandleBroadcasting(const ArgDef& input_grad, } } +void GradientBuilderBase::HandleBroadcastingDynamic(const ArgDef& input_grad, + const ArgDef& target, + const ArgDef& target_shape, + const ArgDef& output_grad, + const ArgDef& reduce_axes, + std::vector& output) const { + ArgDef reduce_grad_arg = IA("ReduceSumTraining_" + input_grad.name + "_for_" + target.name); + output.push_back( + NodeDef(OpDef{"ReduceSumTraining", kMSDomain, 1}, + {input_grad, + reduce_axes}, + {reduce_grad_arg}, + {{"keepdims", ONNX_NAMESPACE::MakeAttribute("keepdims", int64_t(1))}, + {"noop_with_empty_axes", ONNX_NAMESPACE::MakeAttribute("noop_with_empty_axes", int64_t(1))}})); + + output.push_back( + NodeDef("Reshape", + {reduce_grad_arg, target_shape}, + {output_grad})); +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 90f5ed5124..ebfc758d4a 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -21,7 +21,15 @@ void ComputeBroadcastBackwardAxes( std::vector* A_axes, std::vector* B_axes); -std::vector GetShape(const ArgDef& arg_def); +void ComputeBroadcastBackwardAxesDynamic(const ArgDef& a, + const ArgDef& b, + const ArgDef& a_shape, + const ArgDef& b_shape, + const ArgDef* a_axes, + const ArgDef* b_axes, + std::vector& output); + +Status GetShape(const ArgDef& arg_def, std::vector& shape); typedef std::vector GradientDef; @@ -175,6 +183,13 @@ class GradientBuilderBase { const std::vector& reduce_axes, std::vector& output) const; + void HandleBroadcastingDynamic(const ArgDef& input_grad, + const ArgDef& target, + const ArgDef& target_shape, + const ArgDef& output_grad, + const ArgDef& reduce_axes, + std::vector& output) const; + private: friend class GradientGraphBuilder; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 1d7b0abb98..9fa87ac094 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -1128,8 +1128,8 @@ Example 4: } } }); - - ONNX_CONTRIB_OPERATOR_SCHEMA(SplitTraining) + + ONNX_CONTRIB_OPERATOR_SCHEMA(SplitTraining) .SetDomain(kMSDomain) .SinceVersion(1) .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) @@ -1166,7 +1166,7 @@ Example 4: return; } std::vector split = ParseData(split_proto); - + if (!ctx.getInputType(0)->tensor_type().has_shape()) { return; } @@ -1206,7 +1206,7 @@ Example 4: ->mutable_dim(axis) ->set_dim_value(split[i]); } - + }); ONNX_CONTRIB_OPERATOR_SCHEMA(ConcatTraining) @@ -1408,8 +1408,8 @@ Example 4: "Output is an empty vector when no reduction is necessary for the corresponding input.") .Input(0, "a_shape", "The 1st input shape as Tensor.", "T") .Input(1, "b_shape", "The 2nd input shape as Tensor.", "T") - .Output(0, "a_axes", "The reduction axes for 1st input, last to first.", "T") - .Output(1, "b_axes", "The reduction axes for 2nd input, last to first.", "T") + .Output(0, "a_axes", "The reduction axes for 1st input, last to first.", "T", OpSchema::Optional) + .Output(1, "b_axes", "The reduction axes for 2nd input, last to first.", "T", OpSchema::Optional) .TypeConstraint( "T", {"tensor(int64)"}, diff --git a/orttraining/orttraining/test/gradient/gradient_checker.cc b/orttraining/orttraining/test/gradient/gradient_checker.cc index 92892bff21..314b2e9bd5 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.cc +++ b/orttraining/orttraining/test/gradient/gradient_checker.cc @@ -21,7 +21,6 @@ limitations under the License. #include "orttraining/core/graph/gradient_config.h" #include "test/util/include/test_random_seed.h" #include - namespace onnxruntime { namespace test { @@ -112,12 +111,14 @@ inline Status GradientChecker::ComputeTheoreticalJacobianTransp std::vector>* x_datas, std::vector>* y_datas, std::vector>* jacobian_ts, - const std::vector& attributes) { + const std::vector& attributes, + bool add_shape) { size_t y_num = y_infos.size(); size_t x_num = x_infos.size(); // build the graph once and reuse it later in the looping logic GradientOpTester op_session(op_def.type.c_str(), x_infos, y_infos, op_def.opset_version, op_def.domain.c_str(), false); + op_session.AddShapeToTensorData(add_shape); InitOpTesterWithGradGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes); // currently only supported scalar valued fns - and complex types are not supported @@ -287,7 +288,6 @@ inline Status GradientChecker::InitOpTesterWithGradGraph( const std::vector& attributes) { std::unordered_map extra_domain_to_version{{kMSDomain, 1}, {kOnnxDomain, 9}}; InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes, extra_domain_to_version); - // build grad graph auto p_model = op_session.GetModelCache(); auto& graph = p_model->MainGraph(); @@ -328,13 +328,15 @@ inline Status GradientChecker::ComputeNumericJacobianTranspose( std::vector>* x_datas, std::vector>* y_datas, std::vector>* jacobian_ts, - const std::vector& attributes) { + const std::vector& attributes, + bool add_shape) { size_t y_num = y_infos.size(); size_t x_num = x_infos.size(); X_T x_delta = static_cast(delta); // build the graph once and reuse it later in the looping logic OpTester op_session(op_def.type.c_str(), op_def.opset_version, op_def.domain.c_str(), false); + op_session.AddShapeToTensorData(add_shape); InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes); for (int x_idx = 0; x_idx < static_cast(x_num); x_idx++) { @@ -433,11 +435,11 @@ inline Status GradientChecker::ComputeGradientErrorInternal( std::vector>* y_datas, JAC_T* max_error, const std::vector& attributes, - bool check_not_have_gradient) { + bool check_not_have_gradient, + bool check_not_have_shape_inferencing) { // Initialize numeric Jacobian to zeros. std::vector> jacobian_ns; InitJacobians(x_infos, y_infos, &jacobian_ns); - // Compute numeric Jacobian. ORT_RETURN_IF_ERROR(ComputeNumericJacobianTranspose( op_def, x_infos, y_infos, JAC_T{1e-3f}, x_datas, y_datas, &jacobian_ns, attributes)); @@ -445,58 +447,60 @@ inline Status GradientChecker::ComputeGradientErrorInternal( // Compute the maximum error between theoretical and numeric Jacobians. *max_error = 0.0; - // It is necessary to test for inputs with or without gradient. - // We simply set each input without gradient to test the rest inputs' gradient. - // In the last loop it tests for the case where all inputs are with gradient. - size_t total_gradient_variations = check_not_have_gradient ? x_infos.size() + 1 : 1; - for (size_t x_gradient_variation = 0; x_gradient_variation < total_gradient_variations; x_gradient_variation++) { - // Initialize theoretical Jacobians to zeros. - std::vector> jacobian_ts; - InitJacobians(x_infos, y_infos, &jacobian_ts); + int num_grad_builder_checks = check_not_have_shape_inferencing ? 2 : 1; + bool add_shape = true; + for (int i = 0; i < num_grad_builder_checks; i++, add_shape = false) { + // It is necessary to test for inputs with or without gradient. + // We simply set each input without gradient to test the rest inputs' gradient. + // In the last loop it tests for the case where all inputs are with gradient. + size_t total_gradient_variations = check_not_have_gradient ? x_infos.size() + 1 : 1; + for (size_t x_gradient_variation = 0; x_gradient_variation < total_gradient_variations; x_gradient_variation++) { + // Initialize theoretical Jacobians to zeros. + std::vector> jacobian_ts; + InitJacobians(x_infos, y_infos, &jacobian_ts); - std::vector x_infos_gradient_variation = x_infos; + std::vector x_infos_gradient_variation = x_infos; - if (check_not_have_gradient && x_gradient_variation < x_infos.size()) - x_infos_gradient_variation[x_gradient_variation].has_gradient = false; + if (check_not_have_gradient && x_gradient_variation < x_infos.size()) + x_infos_gradient_variation[x_gradient_variation].has_gradient = false; - if (std::all_of(x_infos_gradient_variation.cbegin(), x_infos_gradient_variation.cend(), - [](const TensorInfo& info) { return !info.has_gradient; })) - // a gradient node cannot get created without any has_gradient node. - continue; - - // Compute theoretical Jacobian. - ORT_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose( - op_def, x_infos_gradient_variation, y_infos, x_datas, y_datas, &jacobian_ts, attributes)); - - // We have numeric jacobians regardless of has_gradient (computed once). - // We only have theoretical jacobians for those has_gradient. - // Theoretical jacobians are 0 for those not has_gradient. - int64_t j = 0; - for (auto& x_info : x_infos_gradient_variation) { - if (!x_info.has_gradient) { - // TODO: These 4 test failed at following ORT_ENFORCE. need investigate before enable it. - //GradientCheckerTest.MatMulGrad - //GradientCheckerTest.GemmGrad - //GradientCheckerTest.GatherNDGrad_repeat_float_data - //GradientCheckerTest.GatherNDGrad_unique_float_data - //auto jac_t = jacobian_ts[j]; - //ORT_ENFORCE(std::all_of( - // &jac_t[0], &jac_t[0] + x_info.shape.Size(), [](auto dx) { return dx == 0; })); - j += x_info.shape.Size(); - } else { - for (int r = 0; r < x_info.shape.Size(); j++, r++) { - auto jac_t = jacobian_ts[j]; - auto jac_n = jacobian_ns[j]; - for (size_t i = 0; i < jac_t.size(); i++) { - // dy_i/dx_j for x with gradient. - auto cur_error = std::fabs(jac_t[i] - jac_n[i]); - // Treat any NaN as max_error and immediately return. - // (Note that std::max may ignore NaN arguments.) - if (std::isnan(cur_error)) { - *max_error = cur_error; - return Status::OK(); + if (std::all_of(x_infos_gradient_variation.cbegin(), x_infos_gradient_variation.cend(), + [](const TensorInfo& info) { return !info.has_gradient; })) + // a gradient node cannot get created without any has_gradient node. + continue; + // Compute theoretical Jacobian. + ORT_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose( + op_def, x_infos_gradient_variation, y_infos, x_datas, y_datas, &jacobian_ts, attributes, add_shape)); + // We have numeric jacobians regardless of has_gradient (computed once). + // We only have theoretical jacobians for those has_gradient. + // Theoretical jacobians are 0 for those not has_gradient. + int64_t j = 0; + for (auto& x_info : x_infos_gradient_variation) { + if (!x_info.has_gradient) { + // TODO: These 4 test failed at following ORT_ENFORCE. need investigate before enable it. + //GradientCheckerTest.MatMulGrad + //GradientCheckerTest.GemmGrad + //GradientCheckerTest.GatherNDGrad_repeat_float_data + //GradientCheckerTest.GatherNDGrad_unique_float_data + //auto jac_t = jacobian_ts[j]; + //ORT_ENFORCE(std::all_of( + // &jac_t[0], &jac_t[0] + x_info.shape.Size(), [](auto dx) { return dx == 0; })); + j += x_info.shape.Size(); + } else { + for (int r = 0; r < x_info.shape.Size(); j++, r++) { + auto jac_t = jacobian_ts[j]; + auto jac_n = jacobian_ns[j]; + for (size_t k = 0; k < jac_t.size(); k++) { + // dy_i/dx_j for x with gradient. + auto cur_error = std::fabs(jac_t[k] - jac_n[k]); + // Treat any NaN as max_error and immediately return. + // (Note that std::max may ignore NaN arguments.) + if (std::isnan(cur_error)) { + *max_error = cur_error; + return Status::OK(); + } + *max_error = std::max(*max_error, cur_error); } - *max_error = std::max(*max_error, cur_error); } } } @@ -512,7 +516,8 @@ inline Status GradientChecker::ComputeGradientError( const std::vector& y_infos, JAC_T* max_error, const std::vector& attributes, - bool check_not_have_gradient /* = true*/) { + bool check_not_have_gradient, /* = true*/ + bool check_not_have_shape_inferencing /* = false*/) { // TODO: Consider varying mean and variance float scale = 5.f; float mean = 0.f; @@ -542,7 +547,7 @@ inline Status GradientChecker::ComputeGradientError( // Compute gradient error. return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, - attributes, check_not_have_gradient); + attributes, check_not_have_gradient, check_not_have_shape_inferencing); } template @@ -552,7 +557,9 @@ inline Status GradientChecker::ComputeGradientError( const std::vector& y_infos, JAC_T* max_error, std::vector> x_datas, - const std::vector& attributes) { + const std::vector& attributes, + bool check_not_have_gradient, /* = true*/ + bool check_not_have_shape_inferencing /* = false*/) { // Generate dummy placeholders with zero for y_datas std::vector> y_datas(y_infos.size()); for (size_t i = 0; i < y_infos.size(); i++) { @@ -560,7 +567,8 @@ inline Status GradientChecker::ComputeGradientError( } // Compute gradient error. - return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, attributes); + return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, + attributes, check_not_have_gradient, check_not_have_shape_inferencing); } #define INSTANTIATE_GRAD_ERR_TYPE(X_T, Y_T, JAC_T) \ diff --git a/orttraining/orttraining/test/gradient/gradient_checker.h b/orttraining/orttraining/test/gradient/gradient_checker.h index bd17469228..e37762e8aa 100644 --- a/orttraining/orttraining/test/gradient/gradient_checker.h +++ b/orttraining/orttraining/test/gradient/gradient_checker.h @@ -76,7 +76,9 @@ class GradientChecker { // because the gradient op does not handle the case. We have to use this flag // to disable check for not having gradient cases in order to pass those test. // Remove this flag when the gradient op is fixed. - bool check_not_have_gradient = true); + bool check_not_have_gradient = true, + // Also check gradient builder for op for cases where input shapes are not available + bool check_not_have_shape_inferencing = false); Status ComputeGradientError( const training::OpDef& op_def, @@ -84,7 +86,14 @@ class GradientChecker { const std::vector& y_infos, JAC_T* max_error, std::vector> x_datas, - const std::vector& attributes = {}); + const std::vector& attributes = {}, + // TODO: Ideally it shall check for not has_gradient cases. But some tests are failing + // because the gradient op does not handle the case. We have to use this flag + // to disable check for not having gradient cases in order to pass those test. + // Remove this flag when the gradient op is fixed. + bool check_not_have_gradient = true, + // Also check gradient builder for op for cases where input shapes are not available + bool check_not_have_shape_inferencing = false); private: Status InitJacobians(const std::vector& x_infos, @@ -92,10 +101,10 @@ class GradientChecker { std::vector>* jacobians); std::vector EvaluateFunctionAtInput(OpTester& op_tester, - const std::vector& x_infos, - const std::vector& y_infos, - std::vector>* x_datas, - std::vector>* y_datas); + const std::vector& x_infos, + const std::vector& y_infos, + std::vector>* x_datas, + std::vector>* y_datas); Status InitOpTesterWithGraph(OpTester& op_tester, const std::vector& x_infos, @@ -106,11 +115,11 @@ class GradientChecker { const std::unordered_map& extra_domain_to_version = {}); Status InitOpTesterWithGradGraph(OpTester& op_tester, - const std::vector& x_infos, - const std::vector& y_infos, - std::vector>* x_datas, - std::vector>* y_datas, - const std::vector& attributes); + const std::vector& x_infos, + const std::vector& y_infos, + std::vector>* x_datas, + std::vector>* y_datas, + const std::vector& attributes); Status ComputeTheoreticalJacobianTranspose(const training::OpDef& op_def, const std::vector& x_infos, @@ -118,7 +127,8 @@ class GradientChecker { std::vector>* x_datas, std::vector>* y_datas, std::vector>* jacobian_ts, - const std::vector& attributes); + const std::vector& attributes, + bool add_shape = true); Status ComputeNumericJacobianTranspose(const training::OpDef& op_def, const std::vector& x_infos, @@ -127,7 +137,8 @@ class GradientChecker { std::vector>* x_datas, std::vector>* y_datas, std::vector>* jacobian_ts, - const std::vector& attributes); + const std::vector& attributes, + bool add_shape = true); Status ComputeGradientErrorInternal(const training::OpDef& op_name, const std::vector& x_infos, @@ -136,7 +147,8 @@ class GradientChecker { std::vector>* y_datas, JAC_T* max_error, const std::vector& attributes, - bool check_not_have_gradient = true); + bool check_not_have_gradient = true, + bool check_not_have_shape_inferencing = false); }; } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 9bc844d46c..b5f43f263f 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -101,10 +101,12 @@ TEST(GradientCheckerTest, SqrtGrad) { } void TestBroadcastableBinaryOpGrad(const std::string& op_type, - std::function* transformer = nullptr) { + std::function* transformer = nullptr, + bool check_not_have_shape_inferencing = true) { float max_error; GradientChecker gradient_checker; OpDef op_def{op_type}; + const std::vector attributes = {}; //shape(A) = (2, 3, 4, 5), shape(B) = (2, 3, 4, 5), ==> shape(result) = (2, 3, 4, 5) { @@ -112,7 +114,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{2, 3, 4, 5}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -122,7 +125,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -132,7 +136,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{2, 3, 4, 5}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -142,7 +147,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{5}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -152,7 +158,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{2, 3, 4, 5}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -162,7 +169,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{2, 3, 1, 1}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -172,7 +180,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{2, 1, 1, 1}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -182,7 +191,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{1, 3, 4, 1}, true, transformer}; TensorInfo Y_info{{2, 3, 4, 5}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } @@ -193,7 +203,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{4, 2, 1, 1}, true, transformer, DataTypeImpl::GetTensorType(), {"4", "2", "1", "1"}}; TensorInfo Y_info{{4, 2, 1, 3}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } // symbolic broadcast + numeric broadcast @@ -203,7 +214,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type, TensorInfo B_info{{4, 1, 1, 3}, true, transformer, DataTypeImpl::GetTensorType(), {"batch", "1", "1", "seq"}}; TensorInfo Y_info{{4, 2, 3, 3}}; - gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error); + gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error, + attributes, true, check_not_have_shape_inferencing); EXPECT_IS_TINY(max_error); } } @@ -259,52 +271,61 @@ TEST(GradientCheckerTest, MatMulGrad) { const float error_tolerance = 1e-1f; GradientChecker gradient_checker; OpDef op_def{"MatMul"}; + const std::vector attributes = {}; // 2D x 2D { - gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}}, {{2, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}}, {{2, 3}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 3D x 3D { - gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 3D x 2D { - gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {4, 3}}, {{2, 3, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {4, 3}}, {{2, 3, 3}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 2D x 3D { - gradient_checker.ComputeGradientError(op_def, {{3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 4D x 4D { - gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {2, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {2, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 4D x 2D { - gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {5, 4}}, {{2, 3, 4, 4}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {5, 4}}, {{2, 3, 4, 4}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 4D x 3D { - gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {3, 5, 4}}, {{2, 3, 4, 4}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {3, 5, 4}}, {{2, 3, 4, 4}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // 4D x 4D with broadcast { - gradient_checker.ComputeGradientError(op_def, {{2, 1, 4, 5}, {1, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 1, 4, 5}, {1, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error, + attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } } @@ -324,54 +345,55 @@ TEST(GradientCheckerTest, GemmGrad) { const float error_tolerance = 2e-2f; GradientChecker gradient_checker; OpDef op_def{"Gemm"}; + const std::vector attributes = {}; // Single Batch with Scalar Bias { - gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error, attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // Single Batch with Vector Bias { - gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {3}}, {{1, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {3}}, {{1, 3}}, &max_error, attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // Non-Single Batch with Scalar Bias { - gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {}}, {{2, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {}}, {{2, 3}}, &max_error, attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // Non-Single Batch with Vector Bias { - gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {3}}, {{2, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {3}}, {{2, 3}}, &max_error, attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // Non-Single Batch with Broadcast Bias { - gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {1, 3}}, {{2, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {1, 3}}, {{2, 3}}, &max_error, attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // Non-Single Batch with Non-BroadcastBias { - gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {2, 3}}, {{2, 3}}, &max_error); + gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {2, 3}}, {{2, 3}}, &max_error, attributes, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // TransA { gradient_checker.ComputeGradientError(op_def, {{4, 2}, {4, 3}, {3}}, {{2, 3}}, &max_error, - {MakeAttribute("transA", int64_t(1))}); + {MakeAttribute("transA", int64_t(1))}, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } // TransB { gradient_checker.ComputeGradientError(op_def, {{2, 4}, {3, 4}, {3}}, {{2, 3}}, &max_error, - {MakeAttribute("transB", int64_t(1))}); + {MakeAttribute("transB", int64_t(1))}, true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -379,7 +401,8 @@ TEST(GradientCheckerTest, GemmGrad) { { gradient_checker.ComputeGradientError(op_def, {{4, 2}, {3, 4}, {3}}, {{2, 3}}, &max_error, {MakeAttribute("transA", int64_t(1)), - MakeAttribute("transB", int64_t(1))}); + MakeAttribute("transB", int64_t(1))}, + true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -387,7 +410,8 @@ TEST(GradientCheckerTest, GemmGrad) { { gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {2, 3}}, {{2, 3}}, &max_error, {MakeAttribute("alpha", 0.7f), - MakeAttribute("beta", 5.0f)}); + MakeAttribute("beta", 5.0f)}, + true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -395,7 +419,8 @@ TEST(GradientCheckerTest, GemmGrad) { { gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {3}}, {{2, 3}}, &max_error, {MakeAttribute("alpha", 0.7f), - MakeAttribute("beta", 5.0f)}); + MakeAttribute("beta", 5.0f)}, + true, true); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } } @@ -847,7 +872,9 @@ TEST(GradientCheckerTest, TransposeGrad) { { TensorShape x_shape({2, 3, 4}); TensorShape y_shape({4, 3, 2}); - gradient_checker.ComputeGradientError(op_def, {x_shape}, {y_shape}, &max_error); + const std::vector attributes = {}; + gradient_checker.ComputeGradientError(op_def, {x_shape}, {y_shape}, &max_error, + attributes, true, true /*also test w/o shape inferencing */); EXPECT_IS_TINIER_THAN(max_error, error_tolerance); } @@ -1377,10 +1404,12 @@ void TestBiasGeluGrad(const std::string& op_type, const std::string& domain, int GradientChecker gradient_checker; OpDef op_def{op_type, domain, opset_version}; + const std::vector attributes = {}; float max_error; ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( - op_def, {input_shape, bias_shape}, {input_shape}, &max_error)); + op_def, {input_shape, bias_shape}, {input_shape}, &max_error, + attributes, true, true)); EXPECT_IS_TINY(max_error); } @@ -1852,6 +1881,7 @@ TEST(GradientCheckerTest, ExpandGrad) { float max_error; GradientChecker gradient_checker; OpDef op_def{"Expand"}; + const std::vector attributes = {}; //input_shape = (2, 3, 1), target_shape = (2, 3, 4) ==> shape(result) = (2, 3, 4) { @@ -1861,7 +1891,7 @@ TEST(GradientCheckerTest, ExpandGrad) { TensorInfo y_info({2, 3, 4}, true); - gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true); EXPECT_IS_TINY(max_error); } @@ -1873,7 +1903,7 @@ TEST(GradientCheckerTest, ExpandGrad) { TensorInfo y_info({2, 3, 4}, true); - gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true); EXPECT_IS_TINY(max_error); } @@ -1885,7 +1915,7 @@ TEST(GradientCheckerTest, ExpandGrad) { TensorInfo y_info({2, 3, 4}, true); - gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true); EXPECT_IS_TINY(max_error); } @@ -1897,7 +1927,7 @@ TEST(GradientCheckerTest, ExpandGrad) { TensorInfo y_info({2, 3, 1}, true); - gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true); EXPECT_IS_TINY(max_error); } @@ -1909,7 +1939,7 @@ TEST(GradientCheckerTest, ExpandGrad) { TensorInfo y_info({4, 5, 2, 3}, true); - gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true); EXPECT_IS_TINY(max_error); } @@ -1921,7 +1951,7 @@ TEST(GradientCheckerTest, ExpandGrad) { TensorInfo y_info({4, 5, 2, 3}, true); - gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas); + gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true); EXPECT_IS_TINY(max_error); } } diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/broadcast_grad_args_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/broadcast_grad_args_test.cc index a10c919b10..6a4d3ed3f5 100644 --- a/orttraining/orttraining/test/training_ops/cpu/nn/broadcast_grad_args_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/nn/broadcast_grad_args_test.cc @@ -24,16 +24,18 @@ constexpr auto k_opset_version = 1; void RunBroadcastGradientArgsTest(const char* op, const std::vector& A_shape_tensor, const std::vector& B_shape_tensor, - const std::vector& A_axes_expected, - const std::vector& B_axes_expected, + const std::vector* A_axes_expected, + const std::vector* B_axes_expected, bool fail = false) { OpTester t{op, k_opset_version, kMSDomain}; t.AddInput("a_shape", {static_cast(A_shape_tensor.size())}, A_shape_tensor); t.AddInput("b_shape", {static_cast(B_shape_tensor.size())}, B_shape_tensor); - t.AddOutput("a_axes", {static_cast(A_axes_expected.size())}, A_axes_expected); - t.AddOutput("b_axes", {static_cast(B_axes_expected.size())}, B_axes_expected); + if (A_axes_expected) + t.AddOutput("a_axes", {static_cast(A_axes_expected->size())}, *A_axes_expected); + if (B_axes_expected) + t.AddOutput("b_axes", {static_cast(B_axes_expected->size())}, *B_axes_expected); std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); @@ -48,48 +50,72 @@ void RunBroadcastGradientArgsTest(const char* op, // BroadcastGradientArgs TEST(BroadcastGradientArgsTest, Basic) { + std::vector A_axes_expected = {}; + std::vector B_axes_expected = {1, 0}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 16, 1024, 1024}, {1, 1, 1024, 1024}, - {}, {1, 0}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_both_valid_op) { + std::vector A_axes_expected = {2}; + std::vector B_axes_expected = {1, 0}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 16, 1, 1024}, {1, 1, 1024, 1024}, - {2}, {1, 0}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_no_bcast) { + std::vector A_axes_expected = {}; + std::vector B_axes_expected = {}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 3, 4, 5}, {2, 3, 4, 5}, - {}, {}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_B_scalar) { + std::vector A_axes_expected = {}; + std::vector B_axes_expected = {3, 2, 1, 0}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 3, 4, 5}, {}, - {}, {3, 2, 1, 0}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_B_vector) { + std::vector A_axes_expected = {}; + std::vector B_axes_expected = {2, 1, 0}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 3, 4, 5}, {5}, - {}, {2, 1, 0}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_A_bcast_different_size) { + std::vector A_axes_expected = {1, 0}; + std::vector B_axes_expected = {}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {4, 5}, {2, 3, 4, 5}, - {1, 0}, {}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_both_bcast_different_size) { + std::vector A_axes_expected = {1, 0}; + std::vector B_axes_expected = {3, 2}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {1, 4, 5}, {2, 3, 1, 1}, - {1, 0}, {3, 2}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_both_bcast_different_size_2) { + std::vector A_axes_expected = {0}; + std::vector B_axes_expected = {3, 2, 1}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {3, 4, 5}, {2, 1, 1, 1}, - {0}, {3, 2, 1}); + &A_axes_expected, &B_axes_expected); } TEST(BroadcastGradientArgsTest, Basic_invalid_broadcast) { + std::vector A_axes_expected = {}; + std::vector B_axes_expected = {}; RunBroadcastGradientArgsTest("BroadcastGradientArgs", {3, 4, 5}, {2, 1, 6, 1}, - {}, {}, true /*fail*/); + &A_axes_expected, &B_axes_expected, true /*fail*/); +} + +TEST(BroadcastGradientArgsTest, Basic_only_A_output) { + std::vector A_axes_expected = {0}; + RunBroadcastGradientArgsTest("BroadcastGradientArgs", {3, 4, 5}, {2, 1, 1, 1}, + &A_axes_expected, nullptr); } } // namespace test diff --git a/orttraining/orttraining/training_ops/cpu/nn/broadcast_grad_args.cc b/orttraining/orttraining/training_ops/cpu/nn/broadcast_grad_args.cc index aea6b7832d..b44fdd9103 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/broadcast_grad_args.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/broadcast_grad_args.cc @@ -68,11 +68,18 @@ Status BroadcastGradientArgs::Compute(OpKernelContext* context) const { } Tensor* A_axes = context->Output(0, {static_cast(a_axes.size())}); - T* A_axes_data = A_axes->template MutableData(); - std::copy(a_axes.begin(), a_axes.end(), A_axes_data); + if (A_axes) { //verify as A_axes is an optional output + T* A_axes_data = A_axes->template MutableData(); + std::copy(a_axes.begin(), a_axes.end(), A_axes_data); + } + Tensor* B_axes = context->Output(1, {static_cast(b_axes.size())}); - T* B_axes_data = B_axes->template MutableData(); - std::copy(b_axes.begin(), b_axes.end(), B_axes_data); + if (B_axes) { //verify as B_axes is an optional output + T* B_axes_data = B_axes->template MutableData(); + std::copy(b_axes.begin(), b_axes.end(), B_axes_data); + } + if (!A_axes && !B_axes) + LOGS_DEFAULT(WARNING) << "No output found for op BroadcastGradientArgs."; return Status::OK(); }