mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
update
This commit is contained in:
parent
fbd7b9b8f0
commit
e9de9abc96
2 changed files with 7 additions and 8 deletions
|
|
@ -106,12 +106,12 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
|
|||
if (!node)
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
|
||||
|
||||
auto& inputs = node->InputDefs();
|
||||
auto& inputs = node->MutableInputDefs();
|
||||
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacement_defs;
|
||||
bool casted = false;
|
||||
for (auto input : inputs) {
|
||||
if (NeedInsertCast(node, input)) {
|
||||
auto src_arg = const_cast<onnxruntime::NodeArg*>(input);
|
||||
auto src_arg = input;
|
||||
if (input_def_updates.count(src_arg)) {
|
||||
replacement_defs[src_arg] = input_def_updates[src_arg];
|
||||
} else {
|
||||
|
|
@ -136,7 +136,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
|
|||
node->SetExecutionProviderType(kCpuExecutionProvider);
|
||||
}
|
||||
|
||||
auto& outputs = node->OutputDefs();
|
||||
auto& outputs = node->MutableOutputDefs();
|
||||
for (auto output : outputs) {
|
||||
// todo: check is the kernel available
|
||||
// here is based on the assumption that if we cast a cpu op's input from float16 to float
|
||||
|
|
@ -146,7 +146,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
|
|||
DataTypeImpl::TypeFromProto(*output->TypeAsProto()) == DataTypeImpl::GetTensorType<MLFloat16>() &&
|
||||
casted) {
|
||||
//insert cast op to cast output back to float16
|
||||
auto dst_arg = const_cast<onnxruntime::NodeArg*>(output);
|
||||
auto dst_arg = output;
|
||||
auto src_arg = AddCastNode(graph,
|
||||
id_generator,
|
||||
dst_arg,
|
||||
|
|
@ -172,7 +172,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
|
|||
// remove those two cast.
|
||||
auto src_type = node.InputDefs()[0]->Type();
|
||||
auto dst_type = node.OutputDefs()[0]->Type();
|
||||
auto input = node.InputDefs()[0];
|
||||
auto input = node.MutableInputDefs()[0];
|
||||
int child_removed = 0;
|
||||
int num_child = 0;
|
||||
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
|
||||
|
|
@ -183,8 +183,7 @@ Status InsertCastTransformer::Apply(onnxruntime::Graph& graph, bool& modified) c
|
|||
if (src_type == dst_type1 && src_type1 == dst_type) {
|
||||
//node *it's output's follower could be linked with node's input.
|
||||
replacement_defs.clear();
|
||||
replacement_defs[const_cast<onnxruntime::NodeArg*>(output_node.OutputDefs()[0])] =
|
||||
const_cast<onnxruntime::NodeArg*>(input);
|
||||
replacement_defs[const_cast<onnxruntime::NodeArg*>(output_node.OutputDefs()[0])] = input;
|
||||
for (auto next_it = output_node.OutputNodesBegin(); next_it != output_node.OutputNodesEnd(); ++next_it) {
|
||||
const_cast<onnxruntime::Node*>(&(*next_it))->ReplaceDefs(replacement_defs);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ Status UnsqueezeElimination::Apply(onnxruntime::Graph& graph, bool& modified) co
|
|||
auto& input_defs = output_node->MutableInputDefs();
|
||||
for (auto& def : input_defs) {
|
||||
if (def == output_def) {
|
||||
def = const_cast<NodeArg*>(input_def);
|
||||
def = input_def;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue