mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Add gi overload (#9690)
This commit is contained in:
parent
c6fddb263f
commit
1151c661eb
1 changed files with 6 additions and 0 deletions
|
|
@ -139,6 +139,12 @@ class GradientBuilderBase {
|
|||
return ArgDef(GradientName(node_->InputDefs()[i]->Name()), node_->InputDefs()[i]->TypeAsProto());
|
||||
}
|
||||
|
||||
// gradient of i-th input of forward op - useful when gradient type does not match input type
|
||||
ArgDef GI(const size_t i, const TypeProto *type) const {
|
||||
ORT_ENFORCE(i < node_->InputDefs().size());
|
||||
return ArgDef(GradientName(node_->InputDefs()[i]->Name()), type);
|
||||
}
|
||||
|
||||
// gradient of i-th output of forward op
|
||||
ArgDef GO(const size_t i) const {
|
||||
ORT_ENFORCE(i < node_->OutputDefs().size());
|
||||
|
|
|
|||
Loading…
Reference in a new issue