mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
MatMul Gradient optimization for dB when B is a 2D tensor (#4899)
* Optimized MatMulGrad for dB when B's shape is 2D
* Refactor for ConstantScalarNode

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
6dc85b5f14
commit
9f5d4918dc
2 changed files with 83 additions and 82 deletions
|
|
@ -115,7 +115,44 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
|
|||
|
||||
ArgDef A = I(0), B = I(1), Y = O(0);
|
||||
std::vector<Dimension> A_shape, B_shape, Y_shape;
|
||||
if (GetShape(A, A_shape).IsOK() && GetShape(B, B_shape).IsOK() && GetShape(Y, Y_shape).IsOK()) {
|
||||
const bool A_has_shape = GetShape(A, A_shape).IsOK();
|
||||
const bool B_has_shape = GetShape(B, B_shape).IsOK();
|
||||
const bool Y_has_shape = GetShape(Y, Y_shape).IsOK();
|
||||
|
||||
auto dB_2d_case = [&]() {
|
||||
NodeDef zero_int64_const_node = ConstantScalarNode(int64_t{0}, {1}, Name("zero_int64"));
|
||||
NodeDef one_const_node = ConstantScalarNode(int64_t{1}, {1}, Name("one"));
|
||||
NodeDef neg_one_const_node = ConstantScalarNode(int64_t{-1}, {1}, Name("neg_one"));
|
||||
NodeDef zero_float_const_node = ConstantScalarNode(float{0.0f}, {1}, Name("zero_float"));
|
||||
|
||||
ArgDef ZERO_I = zero_int64_const_node.output_args[0];
|
||||
ArgDef ONE = one_const_node.output_args[0];
|
||||
ArgDef NEG_ONE = neg_one_const_node.output_args[0];
|
||||
ArgDef ZERO_F = zero_float_const_node.output_args[0];
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
zero_int64_const_node,
|
||||
one_const_node,
|
||||
neg_one_const_node,
|
||||
zero_float_const_node,
|
||||
|
||||
NodeDef("Shape", {B}, {IA("B_shape")}),
|
||||
|
||||
// reshape A to 2D [M, K]
|
||||
NodeDef("Gather", {IA("B_shape"), ZERO_I}, {IA("K_dim")}, {MakeAttribute("axis", int64_t(0))}),
|
||||
NodeDef("Concat", {NEG_ONE, IA("K_dim")}, {IA("A_target_shape")}, {MakeAttribute("axis", int64_t(0))}),
|
||||
NodeDef("Reshape", {A, IA("A_target_shape")}, {IA("A_reshape_2d")}),
|
||||
|
||||
// reshape dY to 2D [M, N]
|
||||
NodeDef("Gather", {IA("B_shape"), ONE}, {IA("N_dim")}, {MakeAttribute("axis", int64_t(0))}),
|
||||
NodeDef("Concat", {NEG_ONE, IA("N_dim")}, {IA("dY_target_shape")}, {MakeAttribute("axis", int64_t(0))}),
|
||||
NodeDef("Reshape", {GO(0), IA("dY_target_shape")}, {IA("dY_reshape_2d")}),
|
||||
|
||||
// dB = A' * dY
|
||||
NodeDef("Gemm", {IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO_F}, {GI(1)}, {MakeAttribute("transA", int64_t(1))})};
|
||||
};
|
||||
|
||||
if (A_has_shape && B_has_shape && Y_has_shape) {
|
||||
std::vector<AttributeProto> shared_attributes;
|
||||
shared_attributes.push_back(MakeAttribute("beta", float(0)));
|
||||
AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));
|
||||
|
|
@ -202,53 +239,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
|
|||
}
|
||||
}
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
if (B_shape.size() == 2 &&
|
||||
(B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) &&
|
||||
(B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) {
|
||||
// A[M, K], B[K, N], Y[M, N]
|
||||
int64_t K, N;
|
||||
if (B_shape[0].has_dim_value()) {
|
||||
K = B_shape[0].dim_value();
|
||||
} else {
|
||||
K = A_shape[A_shape.size() - 1].dim_value();
|
||||
}
|
||||
if (B_shape[1].has_dim_value()) {
|
||||
N = B_shape[1].dim_value();
|
||||
} else {
|
||||
N = Y_shape[Y_shape.size() - 1].dim_value();
|
||||
}
|
||||
|
||||
std::vector<int64_t> A_shape_2d{-1, K};
|
||||
NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d"));
|
||||
ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0];
|
||||
result.push_back(A_shape_2d_node);
|
||||
|
||||
std::vector<int64_t> dY_shape_2d{-1, N};
|
||||
NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d"));
|
||||
ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0];
|
||||
result.push_back(dY_shape_2d_node);
|
||||
|
||||
NodeDef zero_constant_node = ZeroConstantNode();
|
||||
ArgDef ZERO = zero_constant_node.output_args[0];
|
||||
result.push_back(zero_constant_node);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{A, A_shape_2d_arg},
|
||||
{IA("A_reshape_2d")}));
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{GO(0), dY_shape_2d_arg},
|
||||
{IA("dY_reshape_2d")}));
|
||||
|
||||
// dB = A' * dY
|
||||
std::vector<AttributeProto> attrs(shared_attributes);
|
||||
attrs.push_back(transpose_first_input);
|
||||
result.push_back(
|
||||
NodeDef("Gemm",
|
||||
{IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO},
|
||||
{GI(1)},
|
||||
attrs));
|
||||
if (B_shape.size() == 2) {
|
||||
// for case: A[M1, M2, ... , K], B[K, N], Y[M1, M2, ..., N]
|
||||
const std::vector<NodeDef> dB_subgraph = dB_2d_case();
|
||||
result.insert(result.end(), dB_subgraph.begin(), dB_subgraph.end());
|
||||
} else {
|
||||
int64_t A_rank = A_shape.size();
|
||||
std::vector<int64_t> A_perm(A_rank);
|
||||
|
|
@ -317,17 +311,23 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
|
|||
HandleBroadcastingDynamic(pre_reduce_grad_0, A, a_shape, GI(0), a_axes, result);
|
||||
}
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1");
|
||||
result.push_back(
|
||||
NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1},
|
||||
{A, GO(0)},
|
||||
{pre_reduce_grad_1},
|
||||
{{"transA", MakeAttribute("transA", int64_t(1))}}));
|
||||
if (B_has_shape && B_shape.size() == 2) {
|
||||
// for case: A[M1, M2, ... , K], B[K, N], Y[M1, M2, ..., N]
|
||||
const std::vector<NodeDef> dB_subgraph = dB_2d_case();
|
||||
result.insert(result.end(), dB_subgraph.begin(), dB_subgraph.end());
|
||||
} else {
|
||||
ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1");
|
||||
result.push_back(
|
||||
NodeDef(OpDef{"TransposeScaleMatMul", kMSDomain, 1},
|
||||
{A, GO(0)},
|
||||
{pre_reduce_grad_1},
|
||||
{{"transA", MakeAttribute("transA", int64_t(1))}}));
|
||||
|
||||
b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name);
|
||||
ia_shape = IA("Shape_" + pre_reduce_grad_1.name);
|
||||
ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result);
|
||||
HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result);
|
||||
b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name);
|
||||
ia_shape = IA("Shape_" + pre_reduce_grad_1.name);
|
||||
ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result);
|
||||
HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -436,7 +436,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result);
|
||||
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
result.push_back(scale_node);
|
||||
result.push_back(
|
||||
|
|
@ -449,7 +449,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
}
|
||||
} else {
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
result.push_back(scale_node);
|
||||
result.push_back(
|
||||
|
|
@ -474,7 +474,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
HandleBroadcastingDynamic(dY, C, c_shape, IA("dC_reduced"), c_axes, result);
|
||||
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
NodeDef scale_node = ConstantScalarNode(beta, {1}, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
result.push_back(scale_node);
|
||||
result.push_back(
|
||||
|
|
@ -1167,7 +1167,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalAveragePoolGradient) {
|
|||
}
|
||||
}
|
||||
|
||||
NodeDef scale_node = ConstantValueNode(1.0f / static_cast<float>(scale), Name("Scale"));
|
||||
NodeDef scale_node = ConstantScalarNode(1.0f / static_cast<float>(scale), {1}, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
return std::vector<NodeDef>{
|
||||
scale_node,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "orttraining/core/graph/recompute_graph_utils.h"
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
@ -159,12 +160,24 @@ class GradientBuilderBase {
|
|||
return node_->OpType();
|
||||
}
|
||||
|
||||
static NodeDef ConstantValueNode(const std::vector<int64_t>& values, const std::string& arg_name) {
|
||||
ONNX_NAMESPACE::TensorProto t_proto;
|
||||
template <typename T>
|
||||
static NodeDef ConstantVectorNode(const std::vector<T>& values, const std::string& arg_name) {
|
||||
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(values);
|
||||
t_proto.add_dims(values.size());
|
||||
t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
for (auto value : values) {
|
||||
t_proto.add_int64_data(value);
|
||||
|
||||
return NodeDef("Constant",
|
||||
{},
|
||||
{ArgDef(arg_name, nullptr)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
|
||||
ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1));
|
||||
|
||||
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(value);
|
||||
for (auto dim : shape) {
|
||||
t_proto.add_dims(dim);
|
||||
}
|
||||
|
||||
return NodeDef("Constant",
|
||||
|
|
@ -173,24 +186,12 @@ class GradientBuilderBase {
|
|||
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
|
||||
}
|
||||
|
||||
static NodeDef ConstantValueNode(float value, const std::string& arg_name) {
|
||||
ONNX_NAMESPACE::TensorProto t_proto;
|
||||
t_proto.add_dims(1);
|
||||
t_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
t_proto.add_float_data(value);
|
||||
|
||||
return NodeDef("Constant",
|
||||
{},
|
||||
{ArgDef(arg_name, nullptr)},
|
||||
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
|
||||
}
|
||||
|
||||
static NodeDef ZeroConstantNode() {
|
||||
return ConstantValueNode(0.0f, "ZeroConstant");
|
||||
return ConstantScalarNode(0.0f, {1}, "ZeroConstant");
|
||||
}
|
||||
|
||||
static NodeDef OneConstantNode() {
|
||||
return ConstantValueNode(1.0f, "OneConstant");
|
||||
return ConstantScalarNode(1.0f, {1}, "OneConstant");
|
||||
}
|
||||
|
||||
void HandleBroadcasting(const ArgDef& input_grad,
|
||||
|
|
@ -232,7 +233,7 @@ class GradientBuilderBase {
|
|||
|
||||
// contains the set of input arg names of node_ which require gradient
|
||||
std::unordered_set<std::string> gradient_outputs_;
|
||||
|
||||
|
||||
const logging::Logger& logger_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue