mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Fix 2 input Gemm grad (#7561)
* Add test for 2 input Gemm grad. * Fix 2 input Gemm grad.
This commit is contained in:
parent
d812354ebd
commit
a647da3e1a
2 changed files with 11 additions and 3 deletions
|
|
@ -352,9 +352,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
bool transA = static_cast<bool>(attributes.at("transA").i());
|
||||
bool transB = static_cast<bool>(attributes.at("transB").i());
|
||||
|
||||
ArgDef A = I(0), B = I(1), C = I(2), dY = GO(0),
|
||||
dA = GI(0), dB = GI(1), dC = GI(2);
|
||||
int elem_type = OElemType(0);
|
||||
ArgDef A = I(0), B = I(1), dY = GO(0),
|
||||
dA = GI(0), dB = GI(1);
|
||||
AttributeProto transpose_first_input = MakeAttribute("transA", int64_t(1));
|
||||
AttributeProto transpose_second_input = MakeAttribute("transB", int64_t(1));
|
||||
|
||||
|
|
@ -431,6 +430,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
if (IsGradientRequiredForSrcNodeInput(2)) {
|
||||
// Y = beta * C
|
||||
// dC = beta * dY
|
||||
ArgDef C = I(2), dC = GI(2);
|
||||
int elem_type = OElemType(0);
|
||||
bool has_beta = attributes.at("beta").has_f();
|
||||
float beta = attributes.at("beta").f();
|
||||
ORT_ENFORCE(beta != 0.0f);
|
||||
|
|
|
|||
|
|
@ -507,6 +507,13 @@ void RunGemmGradTests(const OpDef& op_def) {
|
|||
GradientChecker<float, float, float> gradient_checker;
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {};
|
||||
|
||||
|
||||
// Single Batch no third input
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}}, {{1, 3}}, &max_error, attributes, true, true);
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// Single Batch with Scalar Bias
|
||||
{
|
||||
gradient_checker.ComputeGradientError(op_def, {{1, 4}, {4, 3}, {}}, {{1, 3}}, &max_error, attributes, true, true);
|
||||
|
|
|
|||
Loading…
Reference in a new issue