mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
bug fix for name of gradient constant (#5626)
Co-authored-by: Vincent Wang <weicwang@AiFramework2080ti2.corp.microsoft.com>
This commit is contained in:
parent
b4869926d3
commit
1fa1c51544
1 changed files with 3 additions and 3 deletions
|
|
@ -221,15 +221,15 @@ class GradientBuilderBase {
|
|||
}
|
||||
|
||||
static NodeDef ZeroConstantNode(int elem_type) {
|
||||
return ConstantScalarNode(0.0f, "ZeroConstant", elem_type);
|
||||
return ConstantScalarNode(0.0f, "ZeroConstant_Type" + std::to_string(elem_type), elem_type);
|
||||
}
|
||||
|
||||
static NodeDef HalfConstantNode(int elem_type) {
|
||||
return ConstantScalarNode(0.5f, "HalfConstant", elem_type);
|
||||
return ConstantScalarNode(0.5f, "HalfConstant_Type" + std::to_string(elem_type), elem_type);
|
||||
}
|
||||
|
||||
static NodeDef OneConstantNode(int elem_type) {
|
||||
return ConstantScalarNode(1.0f, "OneConstant", elem_type);
|
||||
return ConstantScalarNode(1.0f, "OneConstant_Type" + std::to_string(elem_type), elem_type);
|
||||
}
|
||||
|
||||
void HandleBroadcasting(const ArgDef& input_grad,
|
||||
|
|
|
|||
Loading…
Reference in a new issue