This commit is contained in:
linkerzhang 2018-11-29 18:36:05 -08:00
parent fbd7b9b8f0
commit e9de9abc96
2 changed files with 7 additions and 8 deletions

View file

@@ -106,12 +106,12 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
if (!node)
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
auto& inputs = node->InputDefs();
auto& inputs = node->MutableInputDefs();
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
bool casted = false;
for (auto input : inputs) {
if (NeedInsertCast(node, input)) {
auto src_arg = const_cast<onnxruntime::NodeArg*>(input);
auto src_arg = input;
if (input_def_updates.count(src_arg)) {
replacement_defs[src_arg] = input_def_updates[src_arg];
} else {
@@ -136,7 +136,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
node->SetExecutionProviderType(kCpuExecutionProvider);
}
auto& outputs = node->OutputDefs();
auto& outputs = node->MutableOutputDefs();
for (auto output : outputs) {
// todo: check is the kernel available
// here is based on the assumption that if we cast a cpu op's input from float16 to float
@@ -146,7 +146,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
DataTypeImpl::TypeFromProto(*output->TypeAsProto()) == DataTypeImpl::GetTensorType<MLFloat16>() &&
casted) {
//insert cast op to cast output back to float16
auto dst_arg = const_cast<onnxruntime::NodeArg*>(output);
auto dst_arg = output;
auto src_arg = AddCastNode(graph,
id_generator,
dst_arg,
@@ -172,7 +172,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
// remove those two cast.
auto src_type = node.InputDefs()[0]->Type();
auto dst_type = node.OutputDefs()[0]->Type();
auto input = node.InputDefs()[0];
auto input = node.MutableInputDefs()[0];
int child_removed = 0;
int num_child = 0;
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
@@ -183,8 +183,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
if (src_type == dst_type1 && src_type1 == dst_type) {
//node *it's output's follower could be linked with node's input.
replacement_defs.clear();
replacement_defs[const_cast<onnxruntime::NodeArg*>(output_node.OutputDefs()[0])] =
const_cast<onnxruntime::NodeArg*>(input);
replacement_defs[const_cast<onnxruntime::NodeArg*>(output_node.OutputDefs()[0])] = input;
for (auto next_it = output_node.OutputNodesBegin(); next_it != output_node.OutputNodesEnd(); ++next_it) {
const_cast<onnxruntime::Node*>(&(*next_it))->ReplaceDefs(replacement_defs);
}

View file

@@ -78,7 +78,7 @@ Status UnsqueezeElimination::Apply(onnxruntime::Graph& graph, bool& modified) co
auto& input_defs = output_node->MutableInputDefs();
for (auto& def : input_defs) {
if (def == output_def) {
def = const_cast<NodeArg*>(input_def);
def = input_def;
}
}
}