From 9f5d4918dc3dc8d2e69ac2e8c369993cdcd4aac3 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Thu, 27 Aug 2020 11:33:20 -0700 Subject: [PATCH] MatMul Gradient optimization for dB when B's is 2D tensor (#4899) * Optimized MatMulGrad for dB when B's shape is 2D * Refactor for ConstantScalarNode Co-authored-by: Sherlock Huang --- .../core/graph/gradient_builder.cc | 124 +++++++++--------- .../core/graph/gradient_builder_base.h | 41 +++--- 2 files changed, 83 insertions(+), 82 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 73a1096da4..bb02758847 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -115,7 +115,44 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { ArgDef A = I(0), B = I(1), Y = O(0); std::vector A_shape, B_shape, Y_shape; - if (GetShape(A, A_shape).IsOK() && GetShape(B, B_shape).IsOK() && GetShape(Y, Y_shape).IsOK()) { + const bool A_has_shape = GetShape(A, A_shape).IsOK(); + const bool B_has_shape = GetShape(B, B_shape).IsOK(); + const bool Y_has_shape = GetShape(Y, Y_shape).IsOK(); + + auto dB_2d_case = [&]() { + NodeDef zero_int64_const_node = ConstantScalarNode(int64_t{0}, {1}, Name("zero_int64")); + NodeDef one_const_node = ConstantScalarNode(int64_t{1}, {1}, Name("one")); + NodeDef neg_one_const_node = ConstantScalarNode(int64_t{-1}, {1}, Name("neg_one")); + NodeDef zero_float_const_node = ConstantScalarNode(float{0.0f}, {1}, Name("zero_float")); + + ArgDef ZERO_I = zero_int64_const_node.output_args[0]; + ArgDef ONE = one_const_node.output_args[0]; + ArgDef NEG_ONE = neg_one_const_node.output_args[0]; + ArgDef ZERO_F = zero_float_const_node.output_args[0]; + + return std::vector{ + zero_int64_const_node, + one_const_node, + neg_one_const_node, + zero_float_const_node, + + NodeDef("Shape", {B}, {IA("B_shape")}), + + // reshape A to 2D [M, K] + NodeDef("Gather", {IA("B_shape"), ZERO_I}, {IA("K_dim")}, {MakeAttribute("axis", int64_t(0))}), + NodeDef("Concat", {NEG_ONE, IA("K_dim")}, {IA("A_target_shape")}, {MakeAttribute("axis", int64_t(0))}), + NodeDef("Reshape", {A, IA("A_target_shape")}, {IA("A_reshape_2d")}), + + // reshape dY to 2D [M, N] + NodeDef("Gather", {IA("B_shape"), ONE}, {IA("N_dim")}, {MakeAttribute("axis", int64_t(0))}), + NodeDef("Concat", {NEG_ONE, IA("N_dim")}, {IA("dY_target_shape")}, {MakeAttribute("axis", int64_t(0))}), + NodeDef("Reshape", {GO(0), IA("dY_target_shape")}, {IA("dY_reshape_2d")}), + + // dB = A' * dY + NodeDef("Gemm", {IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO_F}, {GI(1)}, {MakeAttribute("transA", int64_t(1))})}; + }; + + if (A_has_shape && B_has_shape && Y_has_shape) { std::vector shared_attributes; shared_attributes.push_back(MakeAttribute("beta", float(0))); AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1)); @@ -202,53 +239,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { } } if (IsGradientRequiredForSrcNodeInput(1)) { - if (B_shape.size() == 2 && - (B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) && - (B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) { - // A[M, K], B[K, N], Y[M, N] - int64_t K, N; - if (B_shape[0].has_dim_value()) { - K = B_shape[0].dim_value(); - } else { - K = A_shape[A_shape.size() - 1].dim_value(); - } - if (B_shape[1].has_dim_value()) { - N = B_shape[1].dim_value(); - } else { - N = Y_shape[Y_shape.size() - 1].dim_value(); - } - - std::vector A_shape_2d{-1, K}; - NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d")); - ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0]; - result.push_back(A_shape_2d_node); - - std::vector dY_shape_2d{-1, N}; - NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d")); - ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0]; - result.push_back(dY_shape_2d_node); - - NodeDef zero_constant_node = ZeroConstantNode(); - ArgDef ZERO = zero_constant_node.output_args[0]; - result.push_back(zero_constant_node); - - result.push_back( - NodeDef("Reshape", - {A, A_shape_2d_arg}, - {IA("A_reshape_2d")})); - result.push_back( - NodeDef("Reshape", - {GO(0), dY_shape_2d_arg}, - {IA("dY_reshape_2d")})); - - // dB = A' * dY - std::vector attrs(shared_attributes); - attrs.push_back(transpose_first_input); - result.push_back( - NodeDef("Gemm", - {IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO}, - {GI(1)}, - attrs)); + if (B_shape.size() == 2) { + // for case: A[M1, M2, ... , K], B[K, N], Y[M1, M2, ..., N] + const std::vector dB_subgraph = dB_2d_case(); + result.insert(result.end(), dB_subgraph.begin(), dB_subgraph.end()); } else { int64_t A_rank = A_shape.size(); std::vector A_perm(A_rank); @@ -317,17 +311,23 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) { HandleBroadcastingDynamic(pre_reduce_grad_0, A, a_shape, GI(0), a_axes, result); } if (IsGradientRequiredForSrcNodeInput(1)) { - ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1"); - result.push_back( - NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1}, - {A, GO(0)}, - {pre_reduce_grad_1}, - {{"transA", MakeAttribute("transA", int64_t(1))}})); + if (B_has_shape && B_shape.size() == 2) { + // for case: A[M1, M2, ... , K], B[K, N], Y[M1, M2, ..., N] + const std::vector dB_subgraph = dB_2d_case(); + result.insert(result.end(), dB_subgraph.begin(), dB_subgraph.end()); + } else { + ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1"); + result.push_back( + NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1}, + {A, GO(0)}, + {pre_reduce_grad_1}, + {{"transA", MakeAttribute("transA", int64_t(1))}})); - b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name); - ia_shape = IA("Shape_" + pre_reduce_grad_1.name); - ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result); - HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result); + b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name); + ia_shape = IA("Shape_" + pre_reduce_grad_1.name); + ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result); + HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result); + } } } @@ -436,7 +436,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result); if (has_beta && beta != 1.0f) { - NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); + NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale")); ArgDef SCALE = scale_node.output_args[0]; result.push_back(scale_node); result.push_back( @@ -449,7 +449,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { } } else { if (has_beta && beta != 1.0f) { - NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); + NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale")); ArgDef SCALE = scale_node.output_args[0]; result.push_back(scale_node); result.push_back( @@ -474,7 +474,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { HandleBroadcastingDynamic(dY, C, c_shape, IA("dC_reduced"), c_axes, result); if (has_beta && beta != 1.0f) { - NodeDef scale_node = ConstantValueNode(beta, Name("Scale")); + NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale")); ArgDef SCALE = scale_node.output_args[0]; result.push_back(scale_node); result.push_back( @@ -1167,7 +1167,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) { } } - NodeDef scale_node = ConstantValueNode(1.0f / static_cast(scale), Name("Scale")); + NodeDef scale_node = ConstantScalarNode(1.0f / static_cast(scale), {1}, Name("Scale")); ArgDef SCALE = scale_node.output_args[0]; return std::vector{ scale_node, diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index dcaf7818bc..7a8f3a5007 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -10,6 +10,7 @@ #include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/graph/recompute_graph_utils.h" #include "onnx/defs/attr_proto_util.h" +#include "onnx/defs/tensor_proto_util.h" namespace onnxruntime { namespace training { @@ -159,12 +160,24 @@ class GradientBuilderBase { return node_->OpType(); } - static NodeDef ConstantValueNode(const std::vector& values, const std::string& arg_name) { - ONNX_NAMESPACE::TensorProto t_proto; + template + static NodeDef ConstantVectorNode(const std::vector& values, const std::string& arg_name) { + auto t_proto = ONNX_NAMESPACE::ToTensor(values); t_proto.add_dims(values.size()); - t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - for (auto value : values) { - t_proto.add_int64_data(value); + + return NodeDef("Constant", + {}, + {ArgDef(arg_name, nullptr)}, + {ONNX_NAMESPACE::MakeAttribute("value", t_proto)}); + } + + template + static NodeDef ConstantScalarNode(T value, std::vector shape, const std::string& arg_name) { + ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1)); + + auto t_proto = ONNX_NAMESPACE::ToTensor(value); + for (auto dim : shape) { + t_proto.add_dims(dim); } return NodeDef("Constant", @@ -173,24 +186,12 @@ class GradientBuilderBase { {ONNX_NAMESPACE::MakeAttribute("value", t_proto)}); } - static NodeDef ConstantValueNode(float value, const std::string& arg_name) { - ONNX_NAMESPACE::TensorProto t_proto; - t_proto.add_dims(1); - t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - t_proto.add_float_data(value); - - return NodeDef("Constant", - {}, - {ArgDef(arg_name, nullptr)}, - {ONNX_NAMESPACE::MakeAttribute("value", t_proto)}); - } - static NodeDef ZeroConstantNode() { - return ConstantValueNode(0.0f, "ZeroConstant"); + return ConstantScalarNode(0.0f, {1}, "ZeroConstant"); } static NodeDef OneConstantNode() { - return ConstantValueNode(1.0f, "OneConstant"); + return ConstantScalarNode(1.0f, {1}, "OneConstant"); } void HandleBroadcasting(const ArgDef& input_grad, @@ -232,7 +233,7 @@ class GradientBuilderBase { // contains set of input arg names of node_ which requires gradient std::unordered_set gradient_outputs_; - + const logging::Logger& logger_; };