diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index c81ec427d9..bd615dc048 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -139,6 +139,12 @@ class GradientBuilderBase { return ArgDef(GradientName(node_->InputDefs()[i]->Name()), node_->InputDefs()[i]->TypeAsProto()); } + // gradient of i-th input of forward op - useful when gradient type does not match input type + ArgDef GI(const size_t i, const TypeProto *type) const { + ORT_ENFORCE(i < node_->InputDefs().size()); + return ArgDef(GradientName(node_->InputDefs()[i]->Name()), type); + } + // gradient of i-th output of forward op ArgDef GO(const size_t i) const { ORT_ENFORCE(i < node_->OutputDefs().size());