MatMul Gradient optimization for dB when B is a 2D tensor (#4899)

* Optimized MatMulGrad for dB when B's shape is 2D

* Refactor for ConstantScalarNode

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2020-08-27 11:33:20 -07:00 committed by GitHub
parent 6dc85b5f14
commit 9f5d4918dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 82 deletions

View file

@ -115,7 +115,44 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
ArgDef A = I(0), B = I(1), Y = O(0);
std::vector<Dimension> A_shape, B_shape, Y_shape;
if (GetShape(A, A_shape).IsOK() && GetShape(B, B_shape).IsOK() && GetShape(Y, Y_shape).IsOK()) {
const bool A_has_shape = GetShape(A, A_shape).IsOK();
const bool B_has_shape = GetShape(B, B_shape).IsOK();
const bool Y_has_shape = GetShape(Y, Y_shape).IsOK();
// Builds the subgraph that computes dB (written to GI(1)) for the case where B is 2D:
//   A[M1, M2, ..., K], B[K, N], dY[M1, M2, ..., N].
// A and dY are each reshaped to 2D using K and N gathered from Shape(B), and dB is
// then produced by a single Gemm with transA=1 (dB = A' * dY).
// The nodes are returned as a vector; this lambda does not append to `result` itself.
auto dB_2d_case = [&]() {
// Constants, each created with shape {1}: Gather indices (0 and 1), the -1 entry used
// in the Reshape target shapes, and the float zero passed as Gemm's third input.
NodeDef zero_int64_const_node = ConstantScalarNode(int64_t{0}, {1}, Name("zero_int64"));
NodeDef one_const_node = ConstantScalarNode(int64_t{1}, {1}, Name("one"));
NodeDef neg_one_const_node = ConstantScalarNode(int64_t{-1}, {1}, Name("neg_one"));
NodeDef zero_float_const_node = ConstantScalarNode(float{0.0f}, {1}, Name("zero_float"));
// Output args of the constant nodes, used as inputs to the nodes below.
ArgDef ZERO_I = zero_int64_const_node.output_args[0];
ArgDef ONE = one_const_node.output_args[0];
ArgDef NEG_ONE = neg_one_const_node.output_args[0];
ArgDef ZERO_F = zero_float_const_node.output_args[0];
// The constant nodes are listed first so that their outputs are defined before use.
return std::vector<NodeDef>{
zero_int64_const_node,
one_const_node,
neg_one_const_node,
zero_float_const_node,
// B_shape = Shape(B); K and N are read from it rather than from static shape info.
NodeDef("Shape", {B}, {IA("B_shape")}),
// reshape A to 2D [M, K]
// K_dim = B_shape[0]; target shape is [-1, K_dim], so all leading dims of A collapse into M.
NodeDef("Gather", {IA("B_shape"), ZERO_I}, {IA("K_dim")}, {MakeAttribute("axis", int64_t(0))}),
NodeDef("Concat", {NEG_ONE, IA("K_dim")}, {IA("A_target_shape")}, {MakeAttribute("axis", int64_t(0))}),
NodeDef("Reshape", {A, IA("A_target_shape")}, {IA("A_reshape_2d")}),
// reshape dY to 2D [M, N]
// N_dim = B_shape[1]; target shape is [-1, N_dim]. GO(0) is the incoming gradient dY.
NodeDef("Gather", {IA("B_shape"), ONE}, {IA("N_dim")}, {MakeAttribute("axis", int64_t(0))}),
NodeDef("Concat", {NEG_ONE, IA("N_dim")}, {IA("dY_target_shape")}, {MakeAttribute("axis", int64_t(0))}),
NodeDef("Reshape", {GO(0), IA("dY_target_shape")}, {IA("dY_reshape_2d")}),
// dB = A' * dY
// Only transA is set on this Gemm (no alpha/beta attributes); the third input is the zero constant.
// NOTE(review): ZERO_F is always a float tensor — confirm this satisfies Gemm's type
// constraint when A/dY are not float (e.g. float16).
NodeDef("Gemm", {IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO_F}, {GI(1)}, {MakeAttribute("transA", int64_t(1))})};
};
if (A_has_shape && B_has_shape && Y_has_shape) {
std::vector<AttributeProto> shared_attributes;
shared_attributes.push_back(MakeAttribute("beta", float(0)));
AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));
@ -202,53 +239,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
}
}
if (IsGradientRequiredForSrcNodeInput(1)) {
if (B_shape.size() == 2 &&
(B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) &&
(B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) {
// A[M, K], B[K, N], Y[M, N]
int64_t K, N;
if (B_shape[0].has_dim_value()) {
K = B_shape[0].dim_value();
} else {
K = A_shape[A_shape.size() - 1].dim_value();
}
if (B_shape[1].has_dim_value()) {
N = B_shape[1].dim_value();
} else {
N = Y_shape[Y_shape.size() - 1].dim_value();
}
std::vector<int64_t> A_shape_2d{-1, K};
NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d"));
ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0];
result.push_back(A_shape_2d_node);
std::vector<int64_t> dY_shape_2d{-1, N};
NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d"));
ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0];
result.push_back(dY_shape_2d_node);
NodeDef zero_constant_node = ZeroConstantNode();
ArgDef ZERO = zero_constant_node.output_args[0];
result.push_back(zero_constant_node);
result.push_back(
NodeDef("Reshape",
{A, A_shape_2d_arg},
{IA("A_reshape_2d")}));
result.push_back(
NodeDef("Reshape",
{GO(0), dY_shape_2d_arg},
{IA("dY_reshape_2d")}));
// dB = A' * dY
std::vector<AttributeProto> attrs(shared_attributes);
attrs.push_back(transpose_first_input);
result.push_back(
NodeDef("Gemm",
{IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO},
{GI(1)},
attrs));
if (B_shape.size() == 2) {
// for case: A[M1, M2, ... , K], B[K, N], Y[M1, M2, ..., N]
const std::vector<NodeDef> dB_subgraph = dB_2d_case();
result.insert(result.end(), dB_subgraph.begin(), dB_subgraph.end());
} else {
int64_t A_rank = A_shape.size();
std::vector<int64_t> A_perm(A_rank);
@ -317,17 +311,23 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
HandleBroadcastingDynamic(pre_reduce_grad_0, A, a_shape, GI(0), a_axes, result);
}
if (IsGradientRequiredForSrcNodeInput(1)) {
ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1");
result.push_back(
NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1},
{A, GO(0)},
{pre_reduce_grad_1},
{{"transA", MakeAttribute("transA", int64_t(1))}}));
if (B_has_shape && B_shape.size() == 2) {
// for case: A[M1, M2, ... , K], B[K, N], Y[M1, M2, ..., N]
const std::vector<NodeDef> dB_subgraph = dB_2d_case();
result.insert(result.end(), dB_subgraph.begin(), dB_subgraph.end());
} else {
ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1");
result.push_back(
NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1},
{A, GO(0)},
{pre_reduce_grad_1},
{{"transA", MakeAttribute("transA", int64_t(1))}}));
b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name);
ia_shape = IA("Shape_" + pre_reduce_grad_1.name);
ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result);
HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result);
b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name);
ia_shape = IA("Shape_" + pre_reduce_grad_1.name);
ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result);
HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result);
}
}
}
@ -436,7 +436,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result);
if (has_beta && beta != 1.0f) {
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale"));
ArgDef SCALE = scale_node.output_args[0];
result.push_back(scale_node);
result.push_back(
@ -449,7 +449,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
}
} else {
if (has_beta && beta != 1.0f) {
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale"));
ArgDef SCALE = scale_node.output_args[0];
result.push_back(scale_node);
result.push_back(
@ -474,7 +474,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
HandleBroadcastingDynamic(dY, C, c_shape, IA("dC_reduced"), c_axes, result);
if (has_beta && beta != 1.0f) {
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale"));
ArgDef SCALE = scale_node.output_args[0];
result.push_back(scale_node);
result.push_back(
@ -1167,7 +1167,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) {
}
}
NodeDef scale_node = ConstantValueNode(1.0f / static_cast<float>(scale), Name("Scale"));
NodeDef scale_node = ConstantScalarNode(1.0f / static_cast<float>(scale), {1}, Name("Scale"));
ArgDef SCALE = scale_node.output_args[0];
return std::vector<NodeDef>{
scale_node,

View file

@ -10,6 +10,7 @@
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "onnx/defs/attr_proto_util.h"
#include "onnx/defs/tensor_proto_util.h"
namespace onnxruntime {
namespace training {
@ -159,12 +160,24 @@ class GradientBuilderBase {
return node_->OpType();
}
static NodeDef ConstantValueNode(const std::vector<int64_t>& values, const std::string& arg_name) {
ONNX_NAMESPACE::TensorProto t_proto;
template <typename T>
static NodeDef ConstantVectorNode(const std::vector<T>& values, const std::string& arg_name) {
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(values);
t_proto.add_dims(values.size());
t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
for (auto value : values) {
t_proto.add_int64_data(value);
return NodeDef("Constant",
{},
{ArgDef(arg_name, nullptr)},
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
}
template <typename T>
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1));
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(value);
for (auto dim : shape) {
t_proto.add_dims(dim);
}
return NodeDef("Constant",
@ -173,24 +186,12 @@ class GradientBuilderBase {
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
}
static NodeDef ConstantValueNode(float value, const std::string& arg_name) {
ONNX_NAMESPACE::TensorProto t_proto;
t_proto.add_dims(1);
t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
t_proto.add_float_data(value);
return NodeDef("Constant",
{},
{ArgDef(arg_name, nullptr)},
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
}
static NodeDef ZeroConstantNode() {
return ConstantValueNode(0.0f, "ZeroConstant");
return ConstantScalarNode(0.0f, {1}, "ZeroConstant");
}
static NodeDef OneConstantNode() {
return ConstantValueNode(1.0f, "OneConstant");
return ConstantScalarNode(1.0f, {1}, "OneConstant");
}
void HandleBroadcasting(const ArgDef& input_grad,
@ -232,7 +233,7 @@ class GradientBuilderBase {
// contains set of input arg names of node_ which requires gradient
std::unordered_set<std::string> gradient_outputs_;
const logging::Logger& logger_;
};