mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Shape independent gradient builder for ops requiring broadcast (#4586)
* Adding CPU implementation of BroadcastGradientArgs op
Modify to take shape as input instead of tensor
Cleanup
Correct schema
Corrected kernel, added tests, addressed review comments.
Initial change, to add ReduceSumTraining cpu op
cpu support
Initial changes to gradient builder
Non-empty reduction case passing.
Added exception,test for invalid broadcast,addresed review comments.
Initial change, to add ReduceSumTraining cpu op
cpu support
cuda support + more UTs
on comments + UT
no op support for {} axes with new attr - noop_with_empty_axes
Add noop attribute to ReduceSumTraining use
Add testing for no-shape graph, modify AddSub grad builder, logging.:
MulGrad support
Div support
Expand support
Gemm support
MatMul grad change
Transpose Grad change
BiasGeluGrad change.
Fixes after squash
* Remove logging, add specific exception for shape inference error
* fix build
* Review comments
* Review comments
* Fix windows build
Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
parent
948a33bdfc
commit
d4983f83ff
12 changed files with 721 additions and 427 deletions
|
|
@ -60,7 +60,6 @@ namespace cuda {
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
|
||||
// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
|
||||
template <bool allow_multi_axes>
|
||||
template <typename T, typename OutT, cudnnReduceTensorIndices_t ReduceTensorIndices>
|
||||
|
|
|
|||
|
|
@ -598,7 +598,7 @@ TEST(ReductionOpTest, ReduceMax_int32) {
|
|||
#if defined(OPENVINO_CONFIG_GPU_FP32)
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Disabled temporarily
|
||||
#else
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: axis must be 0
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: axis must be 0
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -404,7 +404,7 @@ class OpTester {
|
|||
template <typename T>
|
||||
void AddAttribute(std::string name, T value) {
|
||||
// Generate a the proper AddAttribute call for later
|
||||
add_attribute_funcs_.emplace_back([name = std::move(name), value = std::move(value)](onnxruntime::Node& node) {
|
||||
add_attribute_funcs_.emplace_back([ name = std::move(name), value = std::move(value) ](onnxruntime::Node & node) {
|
||||
node.AddAttribute(name, value);
|
||||
});
|
||||
}
|
||||
|
|
@ -541,7 +541,7 @@ class OpTester {
|
|||
value.Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
auto node_arg = NodeArg(name, &type_proto.proto);
|
||||
if (dim_params && !(dim_params->empty())) {
|
||||
if (dim_params && !(dim_params->empty()) && add_shape_to_tensor_data_) {
|
||||
// If dim_params presents, configure node_arg's dim value based on dim_params, which supports symbolic dim and dim broadcast.
|
||||
auto& dim_params_data = *dim_params;
|
||||
onnx::TensorShapeProto new_shape;
|
||||
|
|
|
|||
|
|
@ -62,192 +62,221 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
|
|||
std::vector<NodeDef> result;
|
||||
|
||||
ArgDef A = I(0), B = I(1), Y = O(0);
|
||||
std::vector<Dimension> A_shape = GetShape(A);
|
||||
std::vector<Dimension> B_shape = GetShape(B);
|
||||
std::vector<Dimension> Y_shape = GetShape(Y);
|
||||
std::vector<Dimension> A_shape, B_shape, Y_shape;
|
||||
if (GetShape(A, A_shape).IsOK() && GetShape(B, B_shape).IsOK() && GetShape(Y, Y_shape).IsOK()) {
|
||||
std::vector<AttributeProto> shared_attributes;
|
||||
shared_attributes.push_back(MakeAttribute("beta", float(0)));
|
||||
AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));
|
||||
AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1));
|
||||
|
||||
std::vector<AttributeProto> shared_attributes;
|
||||
shared_attributes.push_back(MakeAttribute("beta", float(0)));
|
||||
AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));
|
||||
AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1));
|
||||
if (A_shape.size() == 2 && B_shape.size() == 2) {
|
||||
NodeDef zero_constant_node = ZeroConstantNode();
|
||||
ArgDef ZERO = zero_constant_node.output_args[0];
|
||||
result.push_back(zero_constant_node);
|
||||
|
||||
if (A_shape.size() == 2 && B_shape.size() == 2) {
|
||||
NodeDef zero_constant_node = ZeroConstantNode();
|
||||
ArgDef ZERO = zero_constant_node.output_args[0];
|
||||
result.push_back(zero_constant_node);
|
||||
|
||||
// is GI(0) required
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
// dA = dY * B'
|
||||
std::vector<AttributeProto> attrs(shared_attributes);
|
||||
attrs.push_back(transpose_second_input);
|
||||
result.push_back(
|
||||
NodeDef("Gemm",
|
||||
{GO(0), B, ZERO},
|
||||
{GI(0)},
|
||||
attrs));
|
||||
}
|
||||
|
||||
// is GI(1) required
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
// dB = A' * dY
|
||||
std::vector<AttributeProto> attrs(shared_attributes);
|
||||
attrs.push_back(transpose_first_input);
|
||||
result.push_back(
|
||||
NodeDef("Gemm",
|
||||
{A, GO(0), ZERO},
|
||||
{GI(1)},
|
||||
attrs));
|
||||
}
|
||||
} else if (A_shape.size() > 2 || B_shape.size() > 2) {
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
// If B_shape.size() == 2, dA is computed through 2 ops: transpose and matmul.
|
||||
// It can be replaced with Gemm(dY_reshape, B_transpose) and reshape.
|
||||
// However, there is a performance degradation.
|
||||
// Thus this implementation is not implemented.
|
||||
int64_t B_rank = B_shape.size();
|
||||
std::vector<int64_t> B_perm(B_rank);
|
||||
std::iota(B_perm.begin(), B_perm.end(), 0);
|
||||
std::swap(B_perm[B_rank - 1], B_perm[B_rank - 2]);
|
||||
|
||||
std::vector<Dimension> output_shape;
|
||||
for (size_t i = 0; i < Y_shape.size() - 1; i++) {
|
||||
output_shape.push_back(Y_shape[i]);
|
||||
// is GI(0) required
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
// dA = dY * B'
|
||||
std::vector<AttributeProto> attrs(shared_attributes);
|
||||
attrs.push_back(transpose_second_input);
|
||||
result.push_back(
|
||||
NodeDef("Gemm",
|
||||
{GO(0), B, ZERO},
|
||||
{GI(0)},
|
||||
attrs));
|
||||
}
|
||||
output_shape.push_back(B_shape[B_shape.size() - 2]);
|
||||
|
||||
std::vector<int64_t> A_axes;
|
||||
ComputeBroadcastBackwardAxes(A_shape, output_shape, &A_axes, nullptr);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Transpose",
|
||||
{B},
|
||||
{IA("B_t")},
|
||||
{MakeAttribute("perm", B_perm)}));
|
||||
|
||||
ArgDef matmul_out = A_axes.size() > 0 ? IA("PreReduceGrad0") : GI(0);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("MatMul",
|
||||
{GO(0), IA("B_t")},
|
||||
{matmul_out}));
|
||||
|
||||
if (A_axes.size() > 0) {
|
||||
result.push_back(
|
||||
NodeDef("ReduceSum",
|
||||
{IA("PreReduceGrad0")},
|
||||
{IA("ReduceGrad0")},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t(1))},
|
||||
{"axes", MakeAttribute("axes", A_axes)}}));
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Shape",
|
||||
{A},
|
||||
{IA("A_shape")}));
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{IA("ReduceGrad0"), IA("A_shape")},
|
||||
{GI(0)}));
|
||||
}
|
||||
}
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
if (B_shape.size() == 2 &&
|
||||
(B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) &&
|
||||
(B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) {
|
||||
// A[M, K], B[K, N], Y[M, N]
|
||||
int64_t K, N;
|
||||
if (B_shape[0].has_dim_value()) {
|
||||
K = B_shape[0].dim_value();
|
||||
} else {
|
||||
K = A_shape[A_shape.size() - 1].dim_value();
|
||||
}
|
||||
if (B_shape[1].has_dim_value()) {
|
||||
N = B_shape[1].dim_value();
|
||||
} else {
|
||||
N = Y_shape[Y_shape.size() - 1].dim_value();
|
||||
}
|
||||
|
||||
std::vector<int64_t> A_shape_2d{-1, K};
|
||||
NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d"));
|
||||
ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0];
|
||||
result.push_back(A_shape_2d_node);
|
||||
|
||||
std::vector<int64_t> dY_shape_2d{-1, N};
|
||||
NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d"));
|
||||
ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0];
|
||||
result.push_back(dY_shape_2d_node);
|
||||
|
||||
NodeDef zero_constant_node = ZeroConstantNode();
|
||||
ArgDef ZERO = zero_constant_node.output_args[0];
|
||||
result.push_back(zero_constant_node);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{A, A_shape_2d_arg},
|
||||
{IA("A_reshape_2d")}));
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{GO(0), dY_shape_2d_arg},
|
||||
{IA("dY_reshape_2d")}));
|
||||
|
||||
// is GI(1) required
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
// dB = A' * dY
|
||||
std::vector<AttributeProto> attrs(shared_attributes);
|
||||
attrs.push_back(transpose_first_input);
|
||||
result.push_back(
|
||||
NodeDef("Gemm",
|
||||
{IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO},
|
||||
{A, GO(0), ZERO},
|
||||
{GI(1)},
|
||||
attrs));
|
||||
} else {
|
||||
int64_t A_rank = A_shape.size();
|
||||
std::vector<int64_t> A_perm(A_rank);
|
||||
std::iota(A_perm.begin(), A_perm.end(), 0);
|
||||
std::swap(A_perm[A_rank - 1], A_perm[A_rank - 2]);
|
||||
}
|
||||
} else if (A_shape.size() > 2 || B_shape.size() > 2) {
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
// If B_shape.size() == 2, dA is computed through 2 ops: transpose and matmul.
|
||||
// It can be replaced with Gemm(dY_reshape, B_transpose) and reshape.
|
||||
// However, there is a performance degradation.
|
||||
// Thus this implementation is not implemented.
|
||||
int64_t B_rank = B_shape.size();
|
||||
std::vector<int64_t> B_perm(B_rank);
|
||||
std::iota(B_perm.begin(), B_perm.end(), 0);
|
||||
std::swap(B_perm[B_rank - 1], B_perm[B_rank - 2]);
|
||||
|
||||
std::vector<Dimension> output_shape;
|
||||
for (size_t i = 0; i < Y_shape.size() - 2; i++) {
|
||||
for (size_t i = 0; i < Y_shape.size() - 1; i++) {
|
||||
output_shape.push_back(Y_shape[i]);
|
||||
}
|
||||
output_shape.push_back(A_shape[A_shape.size() - 1]);
|
||||
output_shape.push_back(Y_shape[Y_shape.size() - 1]);
|
||||
output_shape.push_back(B_shape[B_shape.size() - 2]);
|
||||
|
||||
std::vector<int64_t> B_axes;
|
||||
ComputeBroadcastBackwardAxes(B_shape, output_shape, &B_axes, nullptr);
|
||||
std::vector<int64_t> A_axes;
|
||||
ComputeBroadcastBackwardAxes(A_shape, output_shape, &A_axes, nullptr);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Transpose",
|
||||
{A},
|
||||
{IA("A_t")},
|
||||
{MakeAttribute("perm", A_perm)}));
|
||||
{B},
|
||||
{IA("B_t")},
|
||||
{MakeAttribute("perm", B_perm)}));
|
||||
|
||||
ArgDef matmul_out = B_axes.size() > 0 ? IA("PreReduceGrad1") : GI(1);
|
||||
ArgDef matmul_out = A_axes.size() > 0 ? IA("PreReduceGrad0") : GI(0);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("MatMul",
|
||||
{IA("A_t"), GO(0)},
|
||||
{GO(0), IA("B_t")},
|
||||
{matmul_out}));
|
||||
|
||||
if (B_axes.size() > 0) {
|
||||
if (A_axes.size() > 0) {
|
||||
result.push_back(
|
||||
NodeDef("ReduceSum",
|
||||
{IA("PreReduceGrad1")},
|
||||
{IA("ReduceGrad1")},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t(0))},
|
||||
{"axes", MakeAttribute("axes", B_axes)}}));
|
||||
{IA("PreReduceGrad0")},
|
||||
{IA("ReduceGrad0")},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t(1))},
|
||||
{"axes", MakeAttribute("axes", A_axes)}}));
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Shape",
|
||||
{B},
|
||||
{IA("B_shape")}));
|
||||
{A},
|
||||
{IA("A_shape")}));
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{IA("ReduceGrad1"), IA("B_shape")},
|
||||
{GI(1)}));
|
||||
{IA("ReduceGrad0"), IA("A_shape")},
|
||||
{GI(0)}));
|
||||
}
|
||||
}
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
if (B_shape.size() == 2 &&
|
||||
(B_shape[0].has_dim_value() || A_shape[A_shape.size() - 1].has_dim_value()) &&
|
||||
(B_shape[1].has_dim_value() || Y_shape[Y_shape.size() - 1].has_dim_value())) {
|
||||
// A[M, K], B[K, N], Y[M, N]
|
||||
int64_t K, N;
|
||||
if (B_shape[0].has_dim_value()) {
|
||||
K = B_shape[0].dim_value();
|
||||
} else {
|
||||
K = A_shape[A_shape.size() - 1].dim_value();
|
||||
}
|
||||
if (B_shape[1].has_dim_value()) {
|
||||
N = B_shape[1].dim_value();
|
||||
} else {
|
||||
N = Y_shape[Y_shape.size() - 1].dim_value();
|
||||
}
|
||||
|
||||
std::vector<int64_t> A_shape_2d{-1, K};
|
||||
NodeDef A_shape_2d_node = ConstantValueNode(A_shape_2d, Name("A_shape_2d"));
|
||||
ArgDef A_shape_2d_arg = A_shape_2d_node.output_args[0];
|
||||
result.push_back(A_shape_2d_node);
|
||||
|
||||
std::vector<int64_t> dY_shape_2d{-1, N};
|
||||
NodeDef dY_shape_2d_node = ConstantValueNode(dY_shape_2d, Name("dY_shape_2d"));
|
||||
ArgDef dY_shape_2d_arg = dY_shape_2d_node.output_args[0];
|
||||
result.push_back(dY_shape_2d_node);
|
||||
|
||||
NodeDef zero_constant_node = ZeroConstantNode();
|
||||
ArgDef ZERO = zero_constant_node.output_args[0];
|
||||
result.push_back(zero_constant_node);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{A, A_shape_2d_arg},
|
||||
{IA("A_reshape_2d")}));
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{GO(0), dY_shape_2d_arg},
|
||||
{IA("dY_reshape_2d")}));
|
||||
|
||||
// dB = A' * dY
|
||||
std::vector<AttributeProto> attrs(shared_attributes);
|
||||
attrs.push_back(transpose_first_input);
|
||||
result.push_back(
|
||||
NodeDef("Gemm",
|
||||
{IA("A_reshape_2d"), IA("dY_reshape_2d"), ZERO},
|
||||
{GI(1)},
|
||||
attrs));
|
||||
} else {
|
||||
int64_t A_rank = A_shape.size();
|
||||
std::vector<int64_t> A_perm(A_rank);
|
||||
std::iota(A_perm.begin(), A_perm.end(), 0);
|
||||
std::swap(A_perm[A_rank - 1], A_perm[A_rank - 2]);
|
||||
|
||||
std::vector<Dimension> output_shape;
|
||||
for (size_t i = 0; i < Y_shape.size() - 2; i++) {
|
||||
output_shape.push_back(Y_shape[i]);
|
||||
}
|
||||
output_shape.push_back(A_shape[A_shape.size() - 1]);
|
||||
output_shape.push_back(Y_shape[Y_shape.size() - 1]);
|
||||
|
||||
std::vector<int64_t> B_axes;
|
||||
ComputeBroadcastBackwardAxes(B_shape, output_shape, &B_axes, nullptr);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("Transpose",
|
||||
{A},
|
||||
{IA("A_t")},
|
||||
{MakeAttribute("perm", A_perm)}));
|
||||
|
||||
ArgDef matmul_out = B_axes.size() > 0 ? IA("PreReduceGrad1") : GI(1);
|
||||
|
||||
result.push_back(
|
||||
NodeDef("MatMul",
|
||||
{IA("A_t"), GO(0)},
|
||||
{matmul_out}));
|
||||
|
||||
if (B_axes.size() > 0) {
|
||||
result.push_back(
|
||||
NodeDef("ReduceSum",
|
||||
{IA("PreReduceGrad1")},
|
||||
{IA("ReduceGrad1")},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t(0))},
|
||||
{"axes", MakeAttribute("axes", B_axes)}}));
|
||||
result.push_back(
|
||||
NodeDef("Shape",
|
||||
{B},
|
||||
{IA("B_shape")}));
|
||||
result.push_back(
|
||||
NodeDef("Reshape",
|
||||
{IA("ReduceGrad1"), IA("B_shape")},
|
||||
{GI(1)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ORT_THROW("Matmul Gradient Builder shouldn't reach here. ");
|
||||
//GetShape failed, build shape-independent gradient graph
|
||||
ArgDef a_axes, b_axes, a_shape, b_shape, ia_shape;
|
||||
a_shape = IA("Shape_" + A.name);
|
||||
b_shape = IA("Shape_" + B.name);
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
ArgDef pre_reduce_grad_0 = IA("PreReduceGrad0");
|
||||
result.push_back(
|
||||
NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1},
|
||||
{GO(0), B},
|
||||
{pre_reduce_grad_0},
|
||||
{{"transB", MakeAttribute("transB", int64_t(1))}}));
|
||||
|
||||
a_axes = IA("ReduceAxes_" + A.name + "_for_" + A.name);
|
||||
ia_shape = IA("Shape_" + pre_reduce_grad_0.name);
|
||||
ComputeBroadcastBackwardAxesDynamic(A, pre_reduce_grad_0, a_shape, ia_shape, &a_axes, nullptr, result);
|
||||
HandleBroadcastingDynamic(pre_reduce_grad_0, A, a_shape, GI(0), a_axes, result);
|
||||
}
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
ArgDef pre_reduce_grad_1 = IA("PreReduceGrad1");
|
||||
result.push_back(
|
||||
NodeDef(OpDef{"TransposeMatMul", kMSDomain, 1},
|
||||
{A, GO(0)},
|
||||
{pre_reduce_grad_1},
|
||||
{{"transA", MakeAttribute("transA", int64_t(1))}}));
|
||||
|
||||
b_axes = IA("ReduceAxes_" + B.name + "_for_" + B.name);
|
||||
ia_shape = IA("Shape_" + pre_reduce_grad_1.name);
|
||||
ComputeBroadcastBackwardAxesDynamic(pre_reduce_grad_1, B, ia_shape, b_shape, nullptr, &b_axes, result);
|
||||
HandleBroadcastingDynamic(pre_reduce_grad_1, B, b_shape, GI(1), b_axes, result);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
@ -346,15 +375,51 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
bool has_beta = attributes.at("beta").has_f();
|
||||
float beta = attributes.at("beta").f();
|
||||
ORT_ENFORCE(beta != 0.0f);
|
||||
std::vector<Dimension> C_shape, dY_shape;
|
||||
if (GetShape(C, C_shape).IsOK() && GetShape(dY, dY_shape).IsOK()) {
|
||||
std::vector<int64_t> C_axes, dY_axes;
|
||||
ComputeBroadcastBackwardAxes(C_shape, dY_shape, &C_axes, &dY_axes);
|
||||
|
||||
std::vector<Dimension> C_shape = GetShape(C);
|
||||
std::vector<Dimension> dY_shape = GetShape(dY);
|
||||
if (C_axes.size() > 0) {
|
||||
HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result);
|
||||
|
||||
std::vector<int64_t> C_axes, dY_axes;
|
||||
ComputeBroadcastBackwardAxes(C_shape, dY_shape, &C_axes, &dY_axes);
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
result.push_back(scale_node);
|
||||
result.push_back(
|
||||
NodeDef("Mul",
|
||||
{IA("dC_reduced"), SCALE},
|
||||
{dC}));
|
||||
} else {
|
||||
result.push_back(
|
||||
NodeDef("Identity", {IA("dC_reduced")}, {dC}));
|
||||
}
|
||||
} else {
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
result.push_back(scale_node);
|
||||
result.push_back(
|
||||
NodeDef("Mul",
|
||||
{dY, SCALE},
|
||||
{dC}));
|
||||
} else {
|
||||
result.push_back(
|
||||
NodeDef("Identity",
|
||||
{dY},
|
||||
{dC}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//GetShape failed, build shape-independent gradient graph
|
||||
ArgDef c_axes = IA("ReduceAxes_" + C.name);
|
||||
ArgDef c_shape = IA("Shape_" + C.name);
|
||||
ArgDef dy_shape = IA("Shape_" + dY.name);
|
||||
|
||||
if (C_axes.size() > 0) {
|
||||
HandleBroadcasting(dY, C, IA("dC_reduced"), C_axes, result);
|
||||
ComputeBroadcastBackwardAxesDynamic(C, dY, c_shape, dy_shape, &c_axes, nullptr, result);
|
||||
|
||||
HandleBroadcastingDynamic(dY, C, c_shape, IA("dC_reduced"), c_axes, result);
|
||||
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
|
|
@ -368,21 +433,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
result.push_back(
|
||||
NodeDef("Identity", {IA("dC_reduced")}, {dC}));
|
||||
}
|
||||
} else {
|
||||
if (has_beta && beta != 1.0f) {
|
||||
NodeDef scale_node = ConstantValueNode(beta, Name("Scale"));
|
||||
ArgDef SCALE = scale_node.output_args[0];
|
||||
result.push_back(scale_node);
|
||||
result.push_back(
|
||||
NodeDef("Mul",
|
||||
{dY, SCALE},
|
||||
{dC}));
|
||||
} else {
|
||||
result.push_back(
|
||||
NodeDef("Identity",
|
||||
{dY},
|
||||
{dC}));
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
@ -419,7 +469,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetConcatGradient) {
|
|||
std::vector<int64_t> split_attribute(GetSrcNodeInputSize());
|
||||
std::vector<ArgDef> outputs;
|
||||
for (int i = 0; i < GetSrcNodeInputSize(); ++i) {
|
||||
std::vector<Dimension> data_shape = GetShape(I(i));
|
||||
std::vector<Dimension> data_shape;
|
||||
ORT_ENFORCE(GetShape(I(i), data_shape).IsOK());
|
||||
int64_t axis_index = axis < 0 ? static_cast<int64_t>(data_shape.size()) + axis : axis;
|
||||
if (axis_index >= 0 && axis_index < static_cast<int64_t>(data_shape.size()) && data_shape[axis_index].has_dim_value()) {
|
||||
split_attribute[i] = data_shape[axis_index].dim_value();
|
||||
|
|
@ -464,10 +515,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) {
|
|||
IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) {
|
||||
std::vector<int64_t> bw_perm;
|
||||
auto attributes = SrcNodeAttributes();
|
||||
std::vector<AttributeProto> new_attributes;
|
||||
if (attributes.empty()) {
|
||||
const TensorShapeProto& input_shape = I(0).type_proto->tensor_type().shape();
|
||||
for (int i = input_shape.dim_size() - 1; i >= 0; --i) {
|
||||
bw_perm.push_back(i);
|
||||
if (input_shape.dim_size() > 0) { //input_shape is available
|
||||
int n = input_shape.dim_size() - 1;
|
||||
bw_perm.resize(n + 1);
|
||||
std::generate(bw_perm.begin(), bw_perm.end(), [&n] { return n--; });
|
||||
new_attributes.push_back(MakeAttribute("perm", bw_perm));
|
||||
}
|
||||
} else {
|
||||
auto fw_perm = RetrieveValues<int64_t>(attributes.at("perm"));
|
||||
|
|
@ -476,13 +531,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) {
|
|||
for (int i = 0; i < static_cast<int>(size); ++i) {
|
||||
bw_perm[fw_perm[i]] = i;
|
||||
}
|
||||
new_attributes.push_back(MakeAttribute("perm", bw_perm));
|
||||
}
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("Transpose",
|
||||
{GO(0)},
|
||||
{GI(0)},
|
||||
{MakeAttribute("perm", bw_perm)})};
|
||||
new_attributes)};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetAveragePoolGradient) {
|
||||
|
|
@ -624,30 +680,62 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) {
|
|||
bool is_sub = (SrcNodeOpType() == "Sub");
|
||||
|
||||
const ArgDef a = I(0), b = I(1);
|
||||
|
||||
std::vector<Dimension> a_shape = GetShape(a);
|
||||
std::vector<Dimension> b_shape = GetShape(b);
|
||||
|
||||
std::vector<int64_t> a_axes, b_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes);
|
||||
|
||||
std::vector<NodeDef> output;
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(GO(0), a, GI(0), a_axes, output);
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{GO(0)},
|
||||
{GI(0)}));
|
||||
std::vector<Dimension> a_shape, b_shape;
|
||||
if (GetShape(a, a_shape).IsOK() && GetShape(b, b_shape).IsOK()) {
|
||||
std::vector<int64_t> a_axes, b_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes);
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(GO(0), a, GI(0), a_axes, output);
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{GO(0)},
|
||||
{GI(0)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
if (b_axes.size() > 0) {
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
if (b_axes.size() > 0) {
|
||||
ArgDef reshape_output = is_sub ? IA("ReshapeReduceSum_2", IType(1)) : GI(1);
|
||||
HandleBroadcasting(GO(0), b, reshape_output, b_axes, output);
|
||||
|
||||
if (is_sub) {
|
||||
output.push_back(
|
||||
NodeDef("Neg",
|
||||
{reshape_output},
|
||||
{GI(1)}));
|
||||
}
|
||||
} else {
|
||||
if (is_sub) {
|
||||
output.push_back(
|
||||
NodeDef("Neg",
|
||||
{GO(0)},
|
||||
{GI(1)}));
|
||||
} else /*is_add*/ {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{GO(0)},
|
||||
{GI(1)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//GetShape failed, build shape-independent gradient graph
|
||||
ArgDef a_axes = IA("ReduceAxes_" + a.name);
|
||||
ArgDef b_axes = IA("ReduceAxes_" + b.name);
|
||||
ArgDef A_shape = IA("Shape_" + a.name);
|
||||
ArgDef B_shape = IA("Shape_" + b.name);
|
||||
ComputeBroadcastBackwardAxesDynamic(a, b, A_shape, B_shape, &a_axes, &b_axes, output);
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
HandleBroadcastingDynamic(GO(0), a, A_shape, GI(0), a_axes, output);
|
||||
}
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
ArgDef reshape_output = is_sub ? IA("ReshapeReduceSum_2", IType(1)) : GI(1);
|
||||
HandleBroadcasting(GO(0), b, reshape_output, b_axes, output);
|
||||
HandleBroadcastingDynamic(GO(0), b, B_shape, reshape_output, b_axes, output);
|
||||
|
||||
if (is_sub) {
|
||||
output.push_back(
|
||||
|
|
@ -655,18 +743,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) {
|
|||
{reshape_output},
|
||||
{GI(1)}));
|
||||
}
|
||||
} else {
|
||||
if (is_sub) {
|
||||
output.push_back(
|
||||
NodeDef("Neg",
|
||||
{GO(0)},
|
||||
{GI(1)}));
|
||||
} else /*is_add*/ {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{GO(0)},
|
||||
{GI(1)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
return output;
|
||||
|
|
@ -675,44 +751,70 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) {
|
|||
IMPLEMENT_GRADIENT_BUILDER(GetMulGradient) {
|
||||
const ArgDef a = I(0), b = I(1);
|
||||
|
||||
std::vector<Dimension> a_shape = GetShape(a);
|
||||
std::vector<Dimension> b_shape = GetShape(b);
|
||||
std::vector<int64_t> a_axes, b_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes);
|
||||
|
||||
std::vector<NodeDef> output;
|
||||
std::vector<Dimension> a_shape, b_shape;
|
||||
if (GetShape(a, a_shape).IsOK() && GetShape(b, b_shape).IsOK()) {
|
||||
std::vector<int64_t> a_axes, b_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes);
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
output.push_back(
|
||||
NodeDef("Mul",
|
||||
{GO(0), I(1)},
|
||||
{IA("PreReduceGrad0", OType(0))}));
|
||||
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(IA("PreReduceGrad0", OType(0)), a, GI(0), a_axes, output);
|
||||
} else {
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{IA("PreReduceGrad0", OType(0))},
|
||||
{GI(0)}));
|
||||
NodeDef("Mul",
|
||||
{GO(0), I(1)},
|
||||
{IA("PreReduceGrad0", OType(0))}));
|
||||
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(IA("PreReduceGrad0", OType(0)), a, GI(0), a_axes, output);
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{IA("PreReduceGrad0", OType(0))},
|
||||
{GI(0)}));
|
||||
}
|
||||
}
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
output.push_back(
|
||||
NodeDef("Mul",
|
||||
{GO(0), I(0)},
|
||||
{IA("PreReduceGrad1", OType(0))}));
|
||||
|
||||
if (b_axes.size() > 0) {
|
||||
HandleBroadcasting(IA("PreReduceGrad1", OType(0)), b, GI(1), b_axes, output);
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{IA("PreReduceGrad1", OType(0))},
|
||||
{GI(1)}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//GetShape failed, build shape-independent gradient graph
|
||||
ArgDef a_axes = IA("ReduceAxes_" + a.name);
|
||||
ArgDef b_axes = IA("ReduceAxes_" + b.name);
|
||||
ArgDef A_shape = IA("Shape_" + a.name);
|
||||
ArgDef B_shape = IA("Shape_" + b.name);
|
||||
ComputeBroadcastBackwardAxesDynamic(a, b, A_shape, B_shape, &a_axes, &b_axes, output);
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
output.push_back(
|
||||
NodeDef("Mul",
|
||||
{GO(0), I(1)},
|
||||
{IA("PreReduceGrad0", OType(0))}));
|
||||
|
||||
HandleBroadcastingDynamic(IA("PreReduceGrad0", OType(0)), a, A_shape, GI(0), a_axes, output);
|
||||
}
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
output.push_back(
|
||||
NodeDef("Mul",
|
||||
{GO(0), I(0)},
|
||||
{IA("PreReduceGrad1", OType(0))}));
|
||||
|
||||
HandleBroadcastingDynamic(IA("PreReduceGrad1", OType(0)), b, B_shape, GI(1), b_axes, output);
|
||||
}
|
||||
}
|
||||
|
||||
if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
output.push_back(
|
||||
NodeDef("Mul",
|
||||
{GO(0), I(0)},
|
||||
{IA("PreReduceGrad1", OType(0))}));
|
||||
|
||||
if (b_axes.size() > 0) {
|
||||
HandleBroadcasting(IA("PreReduceGrad1", OType(0)), b, GI(1), b_axes, output);
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{IA("PreReduceGrad1", OType(0))},
|
||||
{GI(1)}));
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
|
|
@ -725,17 +827,32 @@ IMPLEMENT_GRADIENT_BUILDER(GetDivGradient) {
|
|||
} else if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
// Y = A / B, dA = dY / B
|
||||
const ArgDef a = I(0), b = I(1);
|
||||
std::vector<int64_t> a_axes, b_axes;
|
||||
ComputeBroadcastBackwardAxes(GetShape(a), GetShape(b), &a_axes, &b_axes);
|
||||
|
||||
std::vector<NodeDef> output;
|
||||
ArgDef tmp_grad = IA("PreReduceGrad0", OType(0));
|
||||
output.push_back(NodeDef("Div", {GO(0), I(1)}, {tmp_grad}));
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(tmp_grad, a, GI(0), a_axes, output);
|
||||
std::vector<Dimension> a_shape, b_shape;
|
||||
if (GetShape(a, a_shape).IsOK() && GetShape(b, b_shape).IsOK()) {
|
||||
std::vector<int64_t> a_axes, b_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, b_shape, &a_axes, &b_axes);
|
||||
|
||||
ArgDef tmp_grad = IA("PreReduceGrad0", OType(0));
|
||||
output.push_back(NodeDef("Div", {GO(0), I(1)}, {tmp_grad}));
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(tmp_grad, a, GI(0), a_axes, output);
|
||||
} else {
|
||||
output.push_back(NodeDef("Identity", {tmp_grad}, {GI(0)}));
|
||||
}
|
||||
} else {
|
||||
output.push_back(NodeDef("Identity", {tmp_grad}, {GI(0)}));
|
||||
//GetShape failed, build shape-independent gradient graph
|
||||
ArgDef a_axes = IA("ReduceAxes_" + a.name);
|
||||
ArgDef A_shape = IA("Shape_" + a.name);
|
||||
ArgDef B_shape = IA("Shape_" + b.name);
|
||||
|
||||
ComputeBroadcastBackwardAxesDynamic(a, b, A_shape, B_shape, &a_axes, nullptr, output);
|
||||
|
||||
ArgDef tmp_grad = IA("PreReduceGrad0", OType(0));
|
||||
output.push_back(NodeDef("Div", {GO(0), I(1)}, {tmp_grad}));
|
||||
HandleBroadcastingDynamic(tmp_grad, a, A_shape, GI(0), a_axes, output);
|
||||
}
|
||||
|
||||
return output;
|
||||
} else if (IsGradientRequiredForSrcNodeInput(1)) {
|
||||
return std::vector<NodeDef>{
|
||||
|
|
@ -907,33 +1024,52 @@ IMPLEMENT_GRADIENT_BUILDER(GetGeluGradient) {
|
|||
namespace {
|
||||
std::vector<NodeDef> GetBiasGeluGradNodes(
|
||||
bool use_approximation,
|
||||
const ArgDef& dY, const ArgDef& X, const ArgDef& B, // inputs
|
||||
const ArgDef& dX, const ArgDef& dB) { // outputs
|
||||
const auto B_shape = GetShape(B);
|
||||
ORT_ENFORCE(B_shape.size() == 1, "B must have exactly one dimension.");
|
||||
const ArgDef& dY, const ArgDef& X, const ArgDef& B, // inputs
|
||||
const ArgDef& dX, const ArgDef& dB, // outputs
|
||||
const ArgDef& b_axes, const ArgDef& b_shape, const ArgDef& x_shape) { //intermediate args
|
||||
std::vector<Dimension> B_shape, X_shape;
|
||||
if (GetShape(B, B_shape).IsOK() && GetShape(X, X_shape).IsOK()) {
|
||||
ORT_ENFORCE(B_shape.size() == 1, "B must have exactly one dimension.");
|
||||
|
||||
const std::vector<int64_t> B_axes = [&B_shape, &X]() {
|
||||
std::vector<int64_t> result{};
|
||||
ComputeBroadcastBackwardAxes(B_shape, GetShape(X), &result, nullptr);
|
||||
const std::vector<int64_t> B_axes = [&B_shape, &X_shape]() {
|
||||
std::vector<int64_t> result{};
|
||||
ComputeBroadcastBackwardAxes(B_shape, X_shape, &result, nullptr);
|
||||
return result;
|
||||
}();
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{use_approximation ? "BiasFastGeluGrad_dX" : "BiasGeluGrad_dX", kMSDomain, 1},
|
||||
{dY, X, B},
|
||||
{dX}),
|
||||
NodeDef("ReduceSum",
|
||||
{dX},
|
||||
{dB},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t{0})},
|
||||
{"axes", MakeAttribute("axes", B_axes)}})};
|
||||
} else {
|
||||
std::vector<NodeDef> result;
|
||||
ComputeBroadcastBackwardAxesDynamic(B, X, b_shape, x_shape, &b_axes, nullptr, result);
|
||||
result.push_back(
|
||||
NodeDef(OpDef{use_approximation ? "BiasFastGeluGrad_dX" : "BiasGeluGrad_dX", kMSDomain, 1},
|
||||
{dY, X, B},
|
||||
{dX}));
|
||||
result.push_back(
|
||||
NodeDef(OpDef{"ReduceSumTraining", kMSDomain, 1},
|
||||
{dX,
|
||||
b_axes},
|
||||
{dB},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t{0})}}));
|
||||
return result;
|
||||
}();
|
||||
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{use_approximation ? "BiasFastGeluGrad_dX" : "BiasGeluGrad_dX", kMSDomain, 1},
|
||||
{dY, X, B},
|
||||
{dX}),
|
||||
NodeDef("ReduceSum",
|
||||
{dX},
|
||||
{dB},
|
||||
{{"keepdims", MakeAttribute("keepdims", int64_t{0})},
|
||||
{"axes", MakeAttribute("axes", B_axes)}})};
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetBiasGeluGradient) {
|
||||
const auto dY = GO(0), X = I(0), B = I(1),
|
||||
dX = GI(0), dB = GI(1);
|
||||
return GetBiasGeluGradNodes(false, dY, X, B, dX, dB);
|
||||
ArgDef b_axes = IA("ReduceAxes_" + B.name);
|
||||
ArgDef b_shape = IA("Shape_" + B.name);
|
||||
ArgDef x_shape = IA("Shape_" + X.name);
|
||||
return GetBiasGeluGradNodes(false, dY, X, B, dX, dB, b_axes, b_shape, x_shape);
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) {
|
||||
|
|
@ -944,7 +1080,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetFastGeluGradient) {
|
|||
// FastGeluGrad doesn't support bias - it needs to be composed with other ops
|
||||
const auto B = I(1),
|
||||
dB = GI(1);
|
||||
return GetBiasGeluGradNodes(true, dY, X, B, dX, dB);
|
||||
ArgDef b_axes = IA("ReduceAxes_" + B.name);
|
||||
ArgDef b_shape = IA("Shape_" + B.name);
|
||||
ArgDef x_shape = IA("Shape_" + X.name);
|
||||
return GetBiasGeluGradNodes(true, dY, X, B, dX, dB, b_axes, b_shape, x_shape);
|
||||
}
|
||||
if (num_src_node_inputs == 1) { // without bias
|
||||
return std::vector<NodeDef>{
|
||||
|
|
@ -1070,19 +1209,30 @@ IMPLEMENT_GRADIENT_BUILDER(GetRecvGradient) {
|
|||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
|
||||
ArgDef a = I(0), y = O(0);
|
||||
std::vector<Dimension> a_shape = GetShape(a);
|
||||
std::vector<Dimension> y_shape = GetShape(y);
|
||||
std::vector<int64_t> a_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, y_shape, &a_axes, nullptr);
|
||||
|
||||
std::vector<NodeDef> output;
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(GO(0), a, GI(0), a_axes, output);
|
||||
|
||||
std::vector<Dimension> a_shape, y_shape;
|
||||
if (GetShape(a, a_shape).IsOK() && GetShape(y, y_shape).IsOK()) {
|
||||
std::vector<int64_t> a_axes;
|
||||
ComputeBroadcastBackwardAxes(a_shape, y_shape, &a_axes, nullptr);
|
||||
|
||||
if (a_axes.size() > 0) {
|
||||
HandleBroadcasting(GO(0), a, GI(0), a_axes, output);
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{GO(0)},
|
||||
{GI(0)}));
|
||||
}
|
||||
} else {
|
||||
output.push_back(
|
||||
NodeDef("Identity",
|
||||
{GO(0)},
|
||||
{GI(0)}));
|
||||
//GetShape failed, build shape-independent gradient graph
|
||||
ArgDef a_axes = IA("ReduceAxes_" + a.name);
|
||||
ArgDef A_shape = IA("Shape_" + a.name);
|
||||
ArgDef Y_shape = IA("Shape_" + y.name);
|
||||
|
||||
ComputeBroadcastBackwardAxesDynamic(a, y, A_shape, Y_shape, &a_axes, nullptr, output);
|
||||
|
||||
HandleBroadcastingDynamic(GO(0), a, A_shape, GI(0), a_axes, output);
|
||||
}
|
||||
|
||||
return output;
|
||||
|
|
|
|||
|
|
@ -85,17 +85,43 @@ void ComputeBroadcastBackwardAxes(
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<Dimension> GetShape(const ArgDef& arg_def) {
|
||||
ORT_ENFORCE(arg_def.type_proto
|
||||
&& arg_def.type_proto->has_tensor_type()
|
||||
&& arg_def.type_proto->tensor_type().has_shape(),
|
||||
"During GetShape, ", arg_def.name, "'s shape is null.");
|
||||
std::vector<Dimension> shape;
|
||||
Status GetShape(const ArgDef& arg_def, std::vector<Dimension>& shape) {
|
||||
shape.clear();
|
||||
ORT_RETURN_IF_NOT(arg_def.type_proto && arg_def.type_proto->has_tensor_type() && arg_def.type_proto->tensor_type().has_shape(),
|
||||
"During GetShape, ", arg_def.name, "'s shape is null.");
|
||||
const auto& dims = arg_def.type_proto->tensor_type().shape().dim();
|
||||
for (auto dim = dims.begin(); dim < dims.end(); dim++) {
|
||||
shape.push_back(*dim);
|
||||
}
|
||||
return shape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void ComputeBroadcastBackwardAxesDynamic(const ArgDef& a,
|
||||
const ArgDef& b,
|
||||
const ArgDef& a_shape,
|
||||
const ArgDef& b_shape,
|
||||
const ArgDef* a_axes,
|
||||
const ArgDef* b_axes,
|
||||
std::vector<NodeDef>& output) {
|
||||
output.push_back(
|
||||
NodeDef("Shape",
|
||||
{a},
|
||||
{a_shape}));
|
||||
|
||||
output.push_back(
|
||||
NodeDef("Shape",
|
||||
{b},
|
||||
{b_shape}));
|
||||
|
||||
ArgDef a_op = ArgDef(""), b_op = ArgDef("");
|
||||
if (a_axes)
|
||||
a_op = *a_axes;
|
||||
if (b_axes)
|
||||
b_op = *b_axes;
|
||||
output.push_back(
|
||||
NodeDef(OpDef{"BroadcastGradientArgs", kMSDomain, 1},
|
||||
{a_shape, b_shape},
|
||||
{a_op, b_op}));
|
||||
}
|
||||
|
||||
void GradientBuilderBase::HandleBroadcasting(const ArgDef& input_grad,
|
||||
|
|
@ -104,9 +130,9 @@ void GradientBuilderBase::HandleBroadcasting(const ArgDef& input_grad,
|
|||
const std::vector<int64_t>& reduce_axes,
|
||||
std::vector<NodeDef>& output) const {
|
||||
std::unordered_set<size_t> reduce_axes_set(reduce_axes.begin(), reduce_axes.end());
|
||||
std::vector<Dimension> reduced_shape;
|
||||
auto input_grad_shape = GetShape(input_grad);
|
||||
auto target_shape = GetShape(target);
|
||||
std::vector<Dimension> reduced_shape, input_grad_shape, target_shape;
|
||||
ORT_ENFORCE(GetShape(input_grad, input_grad_shape).IsOK());
|
||||
ORT_ENFORCE(GetShape(target, target_shape).IsOK());
|
||||
|
||||
bool keep_dims = (input_grad_shape.size() == target_shape.size());
|
||||
|
||||
|
|
@ -165,5 +191,26 @@ void GradientBuilderBase::HandleBroadcasting(const ArgDef& input_grad,
|
|||
}
|
||||
}
|
||||
|
||||
void GradientBuilderBase::HandleBroadcastingDynamic(const ArgDef& input_grad,
|
||||
const ArgDef& target,
|
||||
const ArgDef& target_shape,
|
||||
const ArgDef& output_grad,
|
||||
const ArgDef& reduce_axes,
|
||||
std::vector<NodeDef>& output) const {
|
||||
ArgDef reduce_grad_arg = IA("ReduceSumTraining_" + input_grad.name + "_for_" + target.name);
|
||||
output.push_back(
|
||||
NodeDef(OpDef{"ReduceSumTraining", kMSDomain, 1},
|
||||
{input_grad,
|
||||
reduce_axes},
|
||||
{reduce_grad_arg},
|
||||
{{"keepdims", ONNX_NAMESPACE::MakeAttribute("keepdims", int64_t(1))},
|
||||
{"noop_with_empty_axes", ONNX_NAMESPACE::MakeAttribute("noop_with_empty_axes", int64_t(1))}}));
|
||||
|
||||
output.push_back(
|
||||
NodeDef("Reshape",
|
||||
{reduce_grad_arg, target_shape},
|
||||
{output_grad}));
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -21,7 +21,15 @@ void ComputeBroadcastBackwardAxes(
|
|||
std::vector<int64_t>* A_axes,
|
||||
std::vector<int64_t>* B_axes);
|
||||
|
||||
std::vector<Dimension> GetShape(const ArgDef& arg_def);
|
||||
void ComputeBroadcastBackwardAxesDynamic(const ArgDef& a,
|
||||
const ArgDef& b,
|
||||
const ArgDef& a_shape,
|
||||
const ArgDef& b_shape,
|
||||
const ArgDef* a_axes,
|
||||
const ArgDef* b_axes,
|
||||
std::vector<NodeDef>& output);
|
||||
|
||||
Status GetShape(const ArgDef& arg_def, std::vector<Dimension>& shape);
|
||||
|
||||
typedef std::vector<NodeDef> GradientDef;
|
||||
|
||||
|
|
@ -175,6 +183,13 @@ class GradientBuilderBase {
|
|||
const std::vector<int64_t>& reduce_axes,
|
||||
std::vector<NodeDef>& output) const;
|
||||
|
||||
void HandleBroadcastingDynamic(const ArgDef& input_grad,
|
||||
const ArgDef& target,
|
||||
const ArgDef& target_shape,
|
||||
const ArgDef& output_grad,
|
||||
const ArgDef& reduce_axes,
|
||||
std::vector<NodeDef>& output) const;
|
||||
|
||||
private:
|
||||
friend class GradientGraphBuilder;
|
||||
|
||||
|
|
|
|||
|
|
@ -1128,8 +1128,8 @@ Example 4:
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SplitTraining)
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(SplitTraining)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
|
|
@ -1166,7 +1166,7 @@ Example 4:
|
|||
return;
|
||||
}
|
||||
std::vector<int64_t> split = ParseData<int64_t>(split_proto);
|
||||
|
||||
|
||||
if (!ctx.getInputType(0)->tensor_type().has_shape()) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1206,7 +1206,7 @@ Example 4:
|
|||
->mutable_dim(axis)
|
||||
->set_dim_value(split[i]);
|
||||
}
|
||||
|
||||
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ConcatTraining)
|
||||
|
|
@ -1408,8 +1408,8 @@ Example 4:
|
|||
"Output is an empty vector when no reduction is necessary for the corresponding input.")
|
||||
.Input(0, "a_shape", "The 1st input shape as Tensor.", "T")
|
||||
.Input(1, "b_shape", "The 2nd input shape as Tensor.", "T")
|
||||
.Output(0, "a_axes", "The reduction axes for 1st input, last to first.", "T")
|
||||
.Output(1, "b_axes", "The reduction axes for 2nd input, last to first.", "T")
|
||||
.Output(0, "a_axes", "The reduction axes for 1st input, last to first.", "T", OpSchema::Optional)
|
||||
.Output(1, "b_axes", "The reduction axes for 2nd input, last to first.", "T", OpSchema::Optional)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(int64)"},
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||
#include "orttraining/core/graph/gradient_config.h"
|
||||
#include "test/util/include/test_random_seed.h"
|
||||
#include <random>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
|
|
@ -112,12 +111,14 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeTheoreticalJacobianTransp
|
|||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
std::vector<std::vector<JAC_T>>* jacobian_ts,
|
||||
const std::vector<AttributeProto>& attributes) {
|
||||
const std::vector<AttributeProto>& attributes,
|
||||
bool add_shape) {
|
||||
size_t y_num = y_infos.size();
|
||||
size_t x_num = x_infos.size();
|
||||
|
||||
// build the graph once and reuse it later in the looping logic
|
||||
GradientOpTester op_session(op_def.type.c_str(), x_infos, y_infos, op_def.opset_version, op_def.domain.c_str(), false);
|
||||
op_session.AddShapeToTensorData(add_shape);
|
||||
InitOpTesterWithGradGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes);
|
||||
|
||||
// currently only supported scalar valued fns - and complex types are not supported
|
||||
|
|
@ -287,7 +288,6 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::InitOpTesterWithGradGraph(
|
|||
const std::vector<AttributeProto>& attributes) {
|
||||
std::unordered_map<std::string, int> extra_domain_to_version{{kMSDomain, 1}, {kOnnxDomain, 9}};
|
||||
InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes, extra_domain_to_version);
|
||||
|
||||
// build grad graph
|
||||
auto p_model = op_session.GetModelCache();
|
||||
auto& graph = p_model->MainGraph();
|
||||
|
|
@ -328,13 +328,15 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeNumericJacobianTranspose(
|
|||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
std::vector<std::vector<JAC_T>>* jacobian_ts,
|
||||
const std::vector<AttributeProto>& attributes) {
|
||||
const std::vector<AttributeProto>& attributes,
|
||||
bool add_shape) {
|
||||
size_t y_num = y_infos.size();
|
||||
size_t x_num = x_infos.size();
|
||||
X_T x_delta = static_cast<X_T>(delta);
|
||||
|
||||
// build the graph once and reuse it later in the looping logic
|
||||
OpTester op_session(op_def.type.c_str(), op_def.opset_version, op_def.domain.c_str(), false);
|
||||
op_session.AddShapeToTensorData(add_shape);
|
||||
InitOpTesterWithGraph(op_session, x_infos, y_infos, x_datas, y_datas, attributes);
|
||||
|
||||
for (int x_idx = 0; x_idx < static_cast<int>(x_num); x_idx++) {
|
||||
|
|
@ -433,11 +435,11 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientErrorInternal(
|
|||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
JAC_T* max_error,
|
||||
const std::vector<AttributeProto>& attributes,
|
||||
bool check_not_have_gradient) {
|
||||
bool check_not_have_gradient,
|
||||
bool check_not_have_shape_inferencing) {
|
||||
// Initialize numeric Jacobian to zeros.
|
||||
std::vector<std::vector<JAC_T>> jacobian_ns;
|
||||
InitJacobians(x_infos, y_infos, &jacobian_ns);
|
||||
|
||||
// Compute numeric Jacobian.
|
||||
ORT_RETURN_IF_ERROR(ComputeNumericJacobianTranspose(
|
||||
op_def, x_infos, y_infos, JAC_T{1e-3f}, x_datas, y_datas, &jacobian_ns, attributes));
|
||||
|
|
@ -445,58 +447,60 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientErrorInternal(
|
|||
// Compute the maximum error between theoretical and numeric Jacobians.
|
||||
*max_error = 0.0;
|
||||
|
||||
// It is necessary to test for inputs with or without gradient.
|
||||
// We simply set each input without gradient to test the rest inputs' gradient.
|
||||
// In the last loop it tests for the case where all inputs are with gradient.
|
||||
size_t total_gradient_variations = check_not_have_gradient ? x_infos.size() + 1 : 1;
|
||||
for (size_t x_gradient_variation = 0; x_gradient_variation < total_gradient_variations; x_gradient_variation++) {
|
||||
// Initialize theoretical Jacobians to zeros.
|
||||
std::vector<std::vector<JAC_T>> jacobian_ts;
|
||||
InitJacobians(x_infos, y_infos, &jacobian_ts);
|
||||
int num_grad_builder_checks = check_not_have_shape_inferencing ? 2 : 1;
|
||||
bool add_shape = true;
|
||||
for (int i = 0; i < num_grad_builder_checks; i++, add_shape = false) {
|
||||
// It is necessary to test for inputs with or without gradient.
|
||||
// We simply set each input without gradient to test the rest inputs' gradient.
|
||||
// In the last loop it tests for the case where all inputs are with gradient.
|
||||
size_t total_gradient_variations = check_not_have_gradient ? x_infos.size() + 1 : 1;
|
||||
for (size_t x_gradient_variation = 0; x_gradient_variation < total_gradient_variations; x_gradient_variation++) {
|
||||
// Initialize theoretical Jacobians to zeros.
|
||||
std::vector<std::vector<JAC_T>> jacobian_ts;
|
||||
InitJacobians(x_infos, y_infos, &jacobian_ts);
|
||||
|
||||
std::vector<TensorInfo> x_infos_gradient_variation = x_infos;
|
||||
std::vector<TensorInfo> x_infos_gradient_variation = x_infos;
|
||||
|
||||
if (check_not_have_gradient && x_gradient_variation < x_infos.size())
|
||||
x_infos_gradient_variation[x_gradient_variation].has_gradient = false;
|
||||
if (check_not_have_gradient && x_gradient_variation < x_infos.size())
|
||||
x_infos_gradient_variation[x_gradient_variation].has_gradient = false;
|
||||
|
||||
if (std::all_of(x_infos_gradient_variation.cbegin(), x_infos_gradient_variation.cend(),
|
||||
[](const TensorInfo& info) { return !info.has_gradient; }))
|
||||
// a gradient node cannot get created without any has_gradient node.
|
||||
continue;
|
||||
|
||||
// Compute theoretical Jacobian.
|
||||
ORT_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose(
|
||||
op_def, x_infos_gradient_variation, y_infos, x_datas, y_datas, &jacobian_ts, attributes));
|
||||
|
||||
// We have numeric jacobians regardless of has_gradient (computed once).
|
||||
// We only have theoretical jacobians for those has_gradient.
|
||||
// Theoretical jacobians are 0 for those not has_gradient.
|
||||
int64_t j = 0;
|
||||
for (auto& x_info : x_infos_gradient_variation) {
|
||||
if (!x_info.has_gradient) {
|
||||
// TODO: These 4 test failed at following ORT_ENFORCE. need investigate before enable it.
|
||||
//GradientCheckerTest.MatMulGrad
|
||||
//GradientCheckerTest.GemmGrad
|
||||
//GradientCheckerTest.GatherNDGrad_repeat_float_data
|
||||
//GradientCheckerTest.GatherNDGrad_unique_float_data
|
||||
//auto jac_t = jacobian_ts[j];
|
||||
//ORT_ENFORCE(std::all_of(
|
||||
// &jac_t[0], &jac_t[0] + x_info.shape.Size(), [](auto dx) { return dx == 0; }));
|
||||
j += x_info.shape.Size();
|
||||
} else {
|
||||
for (int r = 0; r < x_info.shape.Size(); j++, r++) {
|
||||
auto jac_t = jacobian_ts[j];
|
||||
auto jac_n = jacobian_ns[j];
|
||||
for (size_t i = 0; i < jac_t.size(); i++) {
|
||||
// dy_i/dx_j for x with gradient.
|
||||
auto cur_error = std::fabs(jac_t[i] - jac_n[i]);
|
||||
// Treat any NaN as max_error and immediately return.
|
||||
// (Note that std::max may ignore NaN arguments.)
|
||||
if (std::isnan(cur_error)) {
|
||||
*max_error = cur_error;
|
||||
return Status::OK();
|
||||
if (std::all_of(x_infos_gradient_variation.cbegin(), x_infos_gradient_variation.cend(),
|
||||
[](const TensorInfo& info) { return !info.has_gradient; }))
|
||||
// a gradient node cannot get created without any has_gradient node.
|
||||
continue;
|
||||
// Compute theoretical Jacobian.
|
||||
ORT_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose(
|
||||
op_def, x_infos_gradient_variation, y_infos, x_datas, y_datas, &jacobian_ts, attributes, add_shape));
|
||||
// We have numeric jacobians regardless of has_gradient (computed once).
|
||||
// We only have theoretical jacobians for those has_gradient.
|
||||
// Theoretical jacobians are 0 for those not has_gradient.
|
||||
int64_t j = 0;
|
||||
for (auto& x_info : x_infos_gradient_variation) {
|
||||
if (!x_info.has_gradient) {
|
||||
// TODO: These 4 test failed at following ORT_ENFORCE. need investigate before enable it.
|
||||
//GradientCheckerTest.MatMulGrad
|
||||
//GradientCheckerTest.GemmGrad
|
||||
//GradientCheckerTest.GatherNDGrad_repeat_float_data
|
||||
//GradientCheckerTest.GatherNDGrad_unique_float_data
|
||||
//auto jac_t = jacobian_ts[j];
|
||||
//ORT_ENFORCE(std::all_of(
|
||||
// &jac_t[0], &jac_t[0] + x_info.shape.Size(), [](auto dx) { return dx == 0; }));
|
||||
j += x_info.shape.Size();
|
||||
} else {
|
||||
for (int r = 0; r < x_info.shape.Size(); j++, r++) {
|
||||
auto jac_t = jacobian_ts[j];
|
||||
auto jac_n = jacobian_ns[j];
|
||||
for (size_t k = 0; k < jac_t.size(); k++) {
|
||||
// dy_i/dx_j for x with gradient.
|
||||
auto cur_error = std::fabs(jac_t[k] - jac_n[k]);
|
||||
// Treat any NaN as max_error and immediately return.
|
||||
// (Note that std::max may ignore NaN arguments.)
|
||||
if (std::isnan(cur_error)) {
|
||||
*max_error = cur_error;
|
||||
return Status::OK();
|
||||
}
|
||||
*max_error = std::max(*max_error, cur_error);
|
||||
}
|
||||
*max_error = std::max(*max_error, cur_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -512,7 +516,8 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientError(
|
|||
const std::vector<TensorInfo>& y_infos,
|
||||
JAC_T* max_error,
|
||||
const std::vector<AttributeProto>& attributes,
|
||||
bool check_not_have_gradient /* = true*/) {
|
||||
bool check_not_have_gradient, /* = true*/
|
||||
bool check_not_have_shape_inferencing /* = false*/) {
|
||||
// TODO: Consider varying mean and variance
|
||||
float scale = 5.f;
|
||||
float mean = 0.f;
|
||||
|
|
@ -542,7 +547,7 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientError(
|
|||
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error,
|
||||
attributes, check_not_have_gradient);
|
||||
attributes, check_not_have_gradient, check_not_have_shape_inferencing);
|
||||
}
|
||||
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
|
|
@ -552,7 +557,9 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientError(
|
|||
const std::vector<TensorInfo>& y_infos,
|
||||
JAC_T* max_error,
|
||||
std::vector<std::vector<X_T>> x_datas,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes) {
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes,
|
||||
bool check_not_have_gradient, /* = true*/
|
||||
bool check_not_have_shape_inferencing /* = false*/) {
|
||||
// Generate dummy placeholders with zero for y_datas
|
||||
std::vector<std::vector<Y_T>> y_datas(y_infos.size());
|
||||
for (size_t i = 0; i < y_infos.size(); i++) {
|
||||
|
|
@ -560,7 +567,8 @@ inline Status GradientChecker<X_T, Y_T, JAC_T>::ComputeGradientError(
|
|||
}
|
||||
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error, attributes);
|
||||
return ComputeGradientErrorInternal(op_def, x_infos, y_infos, &x_datas, &y_datas, max_error,
|
||||
attributes, check_not_have_gradient, check_not_have_shape_inferencing);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GRAD_ERR_TYPE(X_T, Y_T, JAC_T) \
|
||||
|
|
|
|||
|
|
@ -76,7 +76,9 @@ class GradientChecker {
|
|||
// because the gradient op does not handle the case. We have to use this flag
|
||||
// to disable check for not having gradient cases in order to pass those test.
|
||||
// Remove this flag when the gradient op is fixed.
|
||||
bool check_not_have_gradient = true);
|
||||
bool check_not_have_gradient = true,
|
||||
// Also check gradient builder for op for cases where input shapes are not available
|
||||
bool check_not_have_shape_inferencing = false);
|
||||
|
||||
Status ComputeGradientError(
|
||||
const training::OpDef& op_def,
|
||||
|
|
@ -84,7 +86,14 @@ class GradientChecker {
|
|||
const std::vector<TensorInfo>& y_infos,
|
||||
JAC_T* max_error,
|
||||
std::vector<std::vector<X_T>> x_datas,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes = {});
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes = {},
|
||||
// TODO: Ideally it shall check for not has_gradient cases. But some tests are failing
|
||||
// because the gradient op does not handle the case. We have to use this flag
|
||||
// to disable check for not having gradient cases in order to pass those test.
|
||||
// Remove this flag when the gradient op is fixed.
|
||||
bool check_not_have_gradient = true,
|
||||
// Also check gradient builder for op for cases where input shapes are not available
|
||||
bool check_not_have_shape_inferencing = false);
|
||||
|
||||
private:
|
||||
Status InitJacobians(const std::vector<TensorInfo>& x_infos,
|
||||
|
|
@ -92,10 +101,10 @@ class GradientChecker {
|
|||
std::vector<std::vector<JAC_T>>* jacobians);
|
||||
|
||||
std::vector<OrtValue> EvaluateFunctionAtInput(OpTester& op_tester,
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
const std::vector<TensorInfo>& y_infos,
|
||||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas);
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
const std::vector<TensorInfo>& y_infos,
|
||||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas);
|
||||
|
||||
Status InitOpTesterWithGraph(OpTester& op_tester,
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
|
|
@ -106,11 +115,11 @@ class GradientChecker {
|
|||
const std::unordered_map<std::string, int>& extra_domain_to_version = {});
|
||||
|
||||
Status InitOpTesterWithGradGraph(OpTester& op_tester,
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
const std::vector<TensorInfo>& y_infos,
|
||||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes);
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
const std::vector<TensorInfo>& y_infos,
|
||||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes);
|
||||
|
||||
Status ComputeTheoreticalJacobianTranspose(const training::OpDef& op_def,
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
|
|
@ -118,7 +127,8 @@ class GradientChecker {
|
|||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
std::vector<std::vector<JAC_T>>* jacobian_ts,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes);
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes,
|
||||
bool add_shape = true);
|
||||
|
||||
Status ComputeNumericJacobianTranspose(const training::OpDef& op_def,
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
|
|
@ -127,7 +137,8 @@ class GradientChecker {
|
|||
std::vector<std::vector<X_T>>* x_datas,
|
||||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
std::vector<std::vector<JAC_T>>* jacobian_ts,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes);
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes,
|
||||
bool add_shape = true);
|
||||
|
||||
Status ComputeGradientErrorInternal(const training::OpDef& op_name,
|
||||
const std::vector<TensorInfo>& x_infos,
|
||||
|
|
@ -136,7 +147,8 @@ class GradientChecker {
|
|||
std::vector<std::vector<Y_T>>* y_datas,
|
||||
JAC_T* max_error,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes,
|
||||
bool check_not_have_gradient = true);
|
||||
bool check_not_have_gradient = true,
|
||||
bool check_not_have_shape_inferencing = false);
|
||||
};
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -101,10 +101,12 @@ TEST(GradientCheckerTest, SqrtGrad) {
|
|||
}
|
||||
|
||||
void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
||||
std::function<float(float)>* transformer = nullptr) {
|
||||
std::function<float(float)>* transformer = nullptr,
|
||||
bool check_not_have_shape_inferencing = true) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{op_type};
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
|
||||
//shape(A) = (2, 3, 4, 5), shape(B) = (2, 3, 4, 5), ==> shape(result) = (2, 3, 4, 5)
|
||||
{
|
||||
|
|
@ -112,7 +114,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{2, 3, 4, 5}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -122,7 +125,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -132,7 +136,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{2, 3, 4, 5}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -142,7 +147,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{5}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -152,7 +158,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{2, 3, 4, 5}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -162,7 +169,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{2, 3, 1, 1}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -172,7 +180,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{2, 1, 1, 1}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -182,7 +191,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{1, 3, 4, 1}, true, transformer};
|
||||
TensorInfo Y_info{{2, 3, 4, 5}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -193,7 +203,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{4, 2, 1, 1}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"4", "2", "1", "1"}};
|
||||
TensorInfo Y_info{{4, 2, 1, 3}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
// symbolic broadcast + numeric broadcast
|
||||
|
|
@ -203,7 +214,8 @@ void TestBroadcastableBinaryOpGrad(const std::string& op_type,
|
|||
TensorInfo B_info{{4, 1, 1, 3}, true, transformer, DataTypeImpl::GetTensorType<float>(), {"batch", "1", "1", "seq"}};
|
||||
TensorInfo Y_info{{4, 2, 3, 3}};
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {A_info, B_info}, {Y_info}, &max_error,
|
||||
attributes, true, check_not_have_shape_inferencing);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
|
@ -259,52 +271,61 @@ TEST(GradientCheckerTest, MatMulGrad) {
|
|||
const float error_tolerance = 1e-1f;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"MatMul"};
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
|
||||
// 2D x 2D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}}, {{2, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}}, {{2, 3}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 3D x 3D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 3D x 2D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {4, 3}}, {{2, 3, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4}, {4, 3}}, {{2, 3, 3}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D x 3D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{3, 4}, {2, 4, 3}}, {{2, 3, 3}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 4D x 4D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {2, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {2, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 4D x 2D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {5, 4}}, {{2, 3, 4, 4}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {5, 4}}, {{2, 3, 4, 4}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 4D x 3D
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {3, 5, 4}}, {{2, 3, 4, 4}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 3, 4, 5}, {3, 5, 4}}, {{2, 3, 4, 4}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 4D x 4D with broadcast
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 1, 4, 5}, {1, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 1, 4, 5}, {1, 3, 5, 4}}, {{2, 3, 4, 4}}, &max_error,
|
||||
attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
|
@ -324,54 +345,55 @@ TEST(GradientCheckerTest, GemmGrad) {
|
|||
const float error_tolerance = 2e-2f;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Gemm"};
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
|
||||
// Single Batch with Scalar Bias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// Single Batch with Vector Bias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {3}}, {{1, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {3}}, {{1, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// Non-Single Batch with Scalar Bias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {}}, {{2, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {}}, {{2, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// Non-Single Batch with Vector Bias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {3}}, {{2, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {3}}, {{2, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// Non-Single Batch with Broadcast Bias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {1, 3}}, {{2, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {1, 3}}, {{2, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// Non-Single Batch with Non-BroadcastBias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {2, 3}}, {{2, 3}}, &max_error);
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {2, 3}}, {{2, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// TransA
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 2}, {4, 3}, {3}}, {{2, 3}}, &max_error,
|
||||
{MakeAttribute("transA", int64_t(1))});
|
||||
{MakeAttribute("transA", int64_t(1))}, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// TransB
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {3, 4}, {3}}, {{2, 3}}, &max_error,
|
||||
{MakeAttribute("transB", int64_t(1))});
|
||||
{MakeAttribute("transB", int64_t(1))}, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -379,7 +401,8 @@ TEST(GradientCheckerTest, GemmGrad) {
|
|||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{4, 2}, {3, 4}, {3}}, {{2, 3}}, &max_error,
|
||||
{MakeAttribute("transA", int64_t(1)),
|
||||
MakeAttribute("transB", int64_t(1))});
|
||||
MakeAttribute("transB", int64_t(1))},
|
||||
true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -387,7 +410,8 @@ TEST(GradientCheckerTest, GemmGrad) {
|
|||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {2, 3}}, {{2, 3}}, &max_error,
|
||||
{MakeAttribute("alpha", 0.7f),
|
||||
MakeAttribute("beta", 5.0f)});
|
||||
MakeAttribute("beta", 5.0f)},
|
||||
true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -395,7 +419,8 @@ TEST(GradientCheckerTest, GemmGrad) {
|
|||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{2, 4}, {4, 3}, {3}}, {{2, 3}}, &max_error,
|
||||
{MakeAttribute("alpha", 0.7f),
|
||||
MakeAttribute("beta", 5.0f)});
|
||||
MakeAttribute("beta", 5.0f)},
|
||||
true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
|
@ -847,7 +872,9 @@ TEST(GradientCheckerTest, TransposeGrad) {
|
|||
{
|
||||
TensorShape x_shape({2, 3, 4});
|
||||
TensorShape y_shape({4, 3, 2});
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape}, {y_shape}, &max_error);
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape}, {y_shape}, &max_error,
|
||||
attributes, true, true /*also test w/o shape inferencing */);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
|
|
@ -1377,10 +1404,12 @@ void TestBiasGeluGrad(const std::string& op_type, const std::string& domain, int
|
|||
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{op_type, domain, opset_version};
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
|
||||
float max_error;
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(
|
||||
op_def, {input_shape, bias_shape}, {input_shape}, &max_error));
|
||||
op_def, {input_shape, bias_shape}, {input_shape}, &max_error,
|
||||
attributes, true, true));
|
||||
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
|
@ -1852,6 +1881,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"Expand"};
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
|
||||
//input_shape = (2, 3, 1), target_shape = (2, 3, 4) ==> shape(result) = (2, 3, 4)
|
||||
{
|
||||
|
|
@ -1861,7 +1891,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
|
||||
TensorInfo y_info({2, 3, 4}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -1873,7 +1903,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
|
||||
TensorInfo y_info({2, 3, 4}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -1885,7 +1915,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
|
||||
TensorInfo y_info({2, 3, 4}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -1897,7 +1927,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
|
||||
TensorInfo y_info({2, 3, 1}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -1909,7 +1939,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
|
||||
TensorInfo y_info({4, 5, 2, 3}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
|
||||
|
|
@ -1921,7 +1951,7 @@ TEST(GradientCheckerTest, ExpandGrad) {
|
|||
|
||||
TensorInfo y_info({4, 5, 2, 3}, true);
|
||||
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas);
|
||||
gradient_checker.ComputeGradientError(op_def, {x_info, shape_info}, {y_info}, &max_error, x_datas, attributes, true, true);
|
||||
EXPECT_IS_TINY(max_error);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,16 +24,18 @@ constexpr auto k_opset_version = 1;
|
|||
void RunBroadcastGradientArgsTest(const char* op,
|
||||
const std::vector<int64_t>& A_shape_tensor,
|
||||
const std::vector<int64_t>& B_shape_tensor,
|
||||
const std::vector<int64_t>& A_axes_expected,
|
||||
const std::vector<int64_t>& B_axes_expected,
|
||||
const std::vector<int64_t>* A_axes_expected,
|
||||
const std::vector<int64_t>* B_axes_expected,
|
||||
bool fail = false) {
|
||||
OpTester t{op, k_opset_version, kMSDomain};
|
||||
|
||||
t.AddInput("a_shape", {static_cast<int64_t>(A_shape_tensor.size())}, A_shape_tensor);
|
||||
t.AddInput("b_shape", {static_cast<int64_t>(B_shape_tensor.size())}, B_shape_tensor);
|
||||
|
||||
t.AddOutput<int64_t>("a_axes", {static_cast<int64_t>(A_axes_expected.size())}, A_axes_expected);
|
||||
t.AddOutput<int64_t>("b_axes", {static_cast<int64_t>(B_axes_expected.size())}, B_axes_expected);
|
||||
if (A_axes_expected)
|
||||
t.AddOutput<int64_t>("a_axes", {static_cast<int64_t>(A_axes_expected->size())}, *A_axes_expected);
|
||||
if (B_axes_expected)
|
||||
t.AddOutput<int64_t>("b_axes", {static_cast<int64_t>(B_axes_expected->size())}, *B_axes_expected);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
|
|
@ -48,48 +50,72 @@ void RunBroadcastGradientArgsTest(const char* op,
|
|||
// BroadcastGradientArgs
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic) {
|
||||
std::vector<int64_t> A_axes_expected = {};
|
||||
std::vector<int64_t> B_axes_expected = {1, 0};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 16, 1024, 1024}, {1, 1, 1024, 1024},
|
||||
{}, {1, 0});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_both_valid_op) {
|
||||
std::vector<int64_t> A_axes_expected = {2};
|
||||
std::vector<int64_t> B_axes_expected = {1, 0};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 16, 1, 1024}, {1, 1, 1024, 1024},
|
||||
{2}, {1, 0});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_no_bcast) {
|
||||
std::vector<int64_t> A_axes_expected = {};
|
||||
std::vector<int64_t> B_axes_expected = {};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 3, 4, 5}, {2, 3, 4, 5},
|
||||
{}, {});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_B_scalar) {
|
||||
std::vector<int64_t> A_axes_expected = {};
|
||||
std::vector<int64_t> B_axes_expected = {3, 2, 1, 0};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 3, 4, 5}, {},
|
||||
{}, {3, 2, 1, 0});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_B_vector) {
|
||||
std::vector<int64_t> A_axes_expected = {};
|
||||
std::vector<int64_t> B_axes_expected = {2, 1, 0};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {2, 3, 4, 5}, {5},
|
||||
{}, {2, 1, 0});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_A_bcast_different_size) {
|
||||
std::vector<int64_t> A_axes_expected = {1, 0};
|
||||
std::vector<int64_t> B_axes_expected = {};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {4, 5}, {2, 3, 4, 5},
|
||||
{1, 0}, {});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_both_bcast_different_size) {
|
||||
std::vector<int64_t> A_axes_expected = {1, 0};
|
||||
std::vector<int64_t> B_axes_expected = {3, 2};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {1, 4, 5}, {2, 3, 1, 1},
|
||||
{1, 0}, {3, 2});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_both_bcast_different_size_2) {
|
||||
std::vector<int64_t> A_axes_expected = {0};
|
||||
std::vector<int64_t> B_axes_expected = {3, 2, 1};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {3, 4, 5}, {2, 1, 1, 1},
|
||||
{0}, {3, 2, 1});
|
||||
&A_axes_expected, &B_axes_expected);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_invalid_broadcast) {
|
||||
std::vector<int64_t> A_axes_expected = {};
|
||||
std::vector<int64_t> B_axes_expected = {};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {3, 4, 5}, {2, 1, 6, 1},
|
||||
{}, {}, true /*fail*/);
|
||||
&A_axes_expected, &B_axes_expected, true /*fail*/);
|
||||
}
|
||||
|
||||
TEST(BroadcastGradientArgsTest, Basic_only_A_output) {
|
||||
std::vector<int64_t> A_axes_expected = {0};
|
||||
RunBroadcastGradientArgsTest("BroadcastGradientArgs", {3, 4, 5}, {2, 1, 1, 1},
|
||||
&A_axes_expected, nullptr);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -68,11 +68,18 @@ Status BroadcastGradientArgs<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
Tensor* A_axes = context->Output(0, {static_cast<T>(a_axes.size())});
|
||||
T* A_axes_data = A_axes->template MutableData<T>();
|
||||
std::copy(a_axes.begin(), a_axes.end(), A_axes_data);
|
||||
if (A_axes) { //verify as A_axes is an optional output
|
||||
T* A_axes_data = A_axes->template MutableData<T>();
|
||||
std::copy(a_axes.begin(), a_axes.end(), A_axes_data);
|
||||
}
|
||||
|
||||
Tensor* B_axes = context->Output(1, {static_cast<T>(b_axes.size())});
|
||||
T* B_axes_data = B_axes->template MutableData<T>();
|
||||
std::copy(b_axes.begin(), b_axes.end(), B_axes_data);
|
||||
if (B_axes) { //verify as B_axes is an optional output
|
||||
T* B_axes_data = B_axes->template MutableData<T>();
|
||||
std::copy(b_axes.begin(), b_axes.end(), B_axes_data);
|
||||
}
|
||||
if (!A_axes && !B_axes)
|
||||
LOGS_DEFAULT(WARNING) << "No output found for op BroadcastGradientArgs.";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue