diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index 10eab0412a..7822d1ec11 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -221,15 +221,15 @@ class GradientBuilderBase { } static NodeDef ZeroConstantNode(int elem_type) { - return ConstantScalarNode(0.0f, "ZeroConstant", elem_type); + return ConstantScalarNode(0.0f, "ZeroConstant_Type" + std::to_string(elem_type), elem_type); } static NodeDef HalfConstantNode(int elem_type) { - return ConstantScalarNode(0.5f, "HalfConstant", elem_type); + return ConstantScalarNode(0.5f, "HalfConstant_Type" + std::to_string(elem_type), elem_type); } static NodeDef OneConstantNode(int elem_type) { - return ConstantScalarNode(1.0f, "OneConstant", elem_type); + return ConstantScalarNode(1.0f, "OneConstant_Type" + std::to_string(elem_type), elem_type); } void HandleBroadcasting(const ArgDef& input_grad,