Add gi overload (#9690)

This commit is contained in:
ashari4 2021-11-07 16:04:00 -08:00 committed by GitHub
parent c6fddb263f
commit 1151c661eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -139,6 +139,12 @@ class GradientBuilderBase {
return ArgDef(GradientName(node_->InputDefs()[i]->Name()), node_->InputDefs()[i]->TypeAsProto());
}
// gradient of i-th input of forward op - useful when gradient type does not match input type
ArgDef GI(const size_t i, const TypeProto *type) const {
ORT_ENFORCE(i < node_->InputDefs().size());
return ArgDef(GradientName(node_->InputDefs()[i]->Name()), type);
}
// gradient of i-th output of forward op
ArgDef GO(const size_t i) const {
ORT_ENFORCE(i < node_->OutputDefs().size());