From 1151c661ebc7c50a7366d64d1fc31c7099d8ea65 Mon Sep 17 00:00:00 2001 From: ashari4 <70242157+ashari4@users.noreply.github.com> Date: Sun, 7 Nov 2021 16:04:00 -0800 Subject: [PATCH] Add gi overload (#9690) --- orttraining/orttraining/core/graph/gradient_builder_base.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index c81ec427d9..bd615dc048 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -139,6 +139,12 @@ class GradientBuilderBase { return ArgDef(GradientName(node_->InputDefs()[i]->Name()), node_->InputDefs()[i]->TypeAsProto()); } + // gradient of i-th input of forward op - useful when gradient type does not match input type + ArgDef GI(const size_t i, const TypeProto *type) const { + ORT_ENFORCE(i < node_->InputDefs().size()); + return ArgDef(GradientName(node_->InputDefs()[i]->Name()), type); + } + // gradient of i-th output of forward op ArgDef GO(const size_t i) const { ORT_ENFORCE(i < node_->OutputDefs().size());