diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 5984d17cfe..034b4eae02 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -352,9 +352,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { bool transA = static_cast(attributes.at("transA").i()); bool transB = static_cast(attributes.at("transB").i()); - ArgDef A = I(0), B = I(1), C = I(2), dY = GO(0), - dA = GI(0), dB = GI(1), dC = GI(2); - int elem_type = OElemType(0); + ArgDef A = I(0), B = I(1), dY = GO(0), + dA = GI(0), dB = GI(1); AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1)); AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1)); @@ -431,6 +430,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) { if (IsGradientRequiredForSrcNodeInput(2)) { // Y = beta * C // dC = beta * dY + ArgDef C = I(2), dC = GI(2); + int elem_type = OElemType(0); bool has_beta = attributes.at("beta").has_f(); float beta = attributes.at("beta").f(); ORT_ENFORCE(beta != 0.0f); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index e4b3277835..6bbf6fa28a 100755 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -507,6 +507,13 @@ void RunGemmGradTests(const OpDef& op_def) { GradientChecker gradient_checker; const std::vector attributes = {}; + + // Single Batch no third input + { + gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}}, {{1, 3}}, &max_error, attributes, true, true); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + // Single Batch with Scalar Bias { gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error, attributes, true, true);