From e9de9abc961e2ea09ac0550239cd8c8f083e6239 Mon Sep 17 00:00:00 2001 From: linkerzhang Date: Thu, 29 Nov 2018 18:36:05 -0800 Subject: [PATCH] update --- .../core/framework/insert_cast_transformer.cc | 13 ++++++------- onnxruntime/core/graph/unsqueeze_elimination.cc | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/framework/insert_cast_transformer.cc b/onnxruntime/core/framework/insert_cast_transformer.cc index cf46bb5705..eb875f4218 100644 --- a/onnxruntime/core/framework/insert_cast_transformer.cc +++ b/onnxruntime/core/framework/insert_cast_transformer.cc @@ -106,12 +106,12 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c if (!node) return Status(ONNXRUNTIME, INVALID_ARGUMENT); - auto& inputs = node->InputDefs(); + auto& inputs = node->MutableInputDefs(); std::map replacement_defs; bool casted = false; for (auto input : inputs) { if (NeedInsertCast(node, input)) { - auto src_arg = const_cast(input); + auto src_arg = input; if (input_def_updates.count(src_arg)) { replacement_defs[src_arg] = input_def_updates[src_arg]; } else { @@ -136,7 +136,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c node->SetExecutionProviderType(kCpuExecutionProvider); } - auto& outputs = node->OutputDefs(); + auto& outputs = node->MutableOutputDefs(); for (auto output : outputs) { // todo: check is the kernel available // here is based on the assumption that if we cast a cpu op's input from float16 to float @@ -146,7 +146,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c DataTypeImpl::TypeFromProto(*output->TypeAsProto()) == DataTypeImpl::GetTensorType() && casted) { //insert cast op to cast output back to float16 - auto dst_arg = const_cast(output); + auto dst_arg = output; auto src_arg = AddCastNode(graph, id_generator, dst_arg, @@ -172,7 +172,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c // remove those two cast. auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - auto input = node.InputDefs()[0]; + auto input = node.MutableInputDefs()[0]; int child_removed = 0; int num_child = 0; for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { @@ -183,8 +183,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c if (src_type == dst_type1 && src_type1 == dst_type) { //node *it's output's follower could be linked with node's input. replacement_defs.clear(); - replacement_defs[const_cast(output_node.OutputDefs()[0])] = - const_cast(input); + replacement_defs[const_cast(output_node.OutputDefs()[0])] = input; for (auto next_it = output_node.OutputNodesBegin(); next_it != output_node.OutputNodesEnd(); ++next_it) { const_cast(&(*next_it))->ReplaceDefs(replacement_defs); } diff --git a/onnxruntime/core/graph/unsqueeze_elimination.cc b/onnxruntime/core/graph/unsqueeze_elimination.cc index f7b32f9643..31b0347c24 100644 --- a/onnxruntime/core/graph/unsqueeze_elimination.cc +++ b/onnxruntime/core/graph/unsqueeze_elimination.cc @@ -78,7 +78,7 @@ Status UnsqueezeElimination::Apply(onnxruntime::Graph& graph, bool& modified) co auto& input_defs = output_node->MutableInputDefs(); for (auto& def : input_defs) { if (def == output_def) { - def = const_cast(input_def); + def = input_def; } } }