Fix 2 input Gemm grad (#7561)

* Add test for 2 input Gemm grad.

* Fix 2 input Gemm grad.
This commit is contained in:
Sergii Dymchenko 2021-05-04 12:00:14 -07:00 committed by GitHub
parent d812354ebd
commit a647da3e1a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View file

@ -352,9 +352,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
bool transA = static_cast<bool>(attributes.at("transA").i());
bool transB = static_cast<bool>(attributes.at("transB").i());
ArgDef A = I(0), B = I(1), C = I(2), dY = GO(0),
dA = GI(0), dB = GI(1), dC = GI(2);
int elem_type = OElemType(0);
ArgDef A = I(0), B = I(1), dY = GO(0),
dA = GI(0), dB = GI(1);
AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));
AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1));
@ -431,6 +430,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
if (IsGradientRequiredForSrcNodeInput(2)) {
// Y = beta * C
// dC = beta * dY
ArgDef C = I(2), dC = GI(2);
int elem_type = OElemType(0);
bool has_beta = attributes.at("beta").has_f();
float beta = attributes.at("beta").f();
ORT_ENFORCE(beta != 0.0f);

View file

@ -507,6 +507,13 @@ void RunGemmGradTests(const OpDef& op_def) {
GradientChecker<float, float, float> gradient_checker;
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
// Single Batch no third input
{
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}}, {{1, 3}}, &max_error, attributes, true, true);
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
// Single Batch with Scalar Bias
{
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error, attributes, true, true);