mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Fix mixed-precision handling for the square case (#8333)
* handle unsqueeze change in opset13
* fix the node arguments index check for square case (x * x)
* Revert "fix the node arguments index check for square case (x * x)"
  This reverts commit c66344f0a82c35d8c24d31f2264cf7e9b235ce22.
* handle the square case (x * x) for node argument search

Co-authored-by: Cheng Tang <chenta@microsoft.com>
This commit is contained in:
parent
187743726b
commit
598454bb5f
1 changed file with 14 additions and 13 deletions
|
|
@ -84,29 +84,30 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph,
|
|||
std::vector<std::pair<Node*, int>>& fp32_inputs) {
|
||||
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(arg->Name());
|
||||
for (Node* node : consumer_nodes) {
|
||||
int node_arg_slot = -1;
|
||||
std::vector<int> node_arg_slots;
|
||||
for (int i = 0; i < static_cast<int>(node->InputDefs().size()); i++) {
|
||||
if (node->InputDefs()[i] == arg) {
|
||||
node_arg_slot = i;
|
||||
break;
|
||||
node_arg_slots.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (node_arg_slot == -1) {
|
||||
if (node_arg_slots.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto it = fp32_node_args_by_op_type.find(node->OpType());
|
||||
if (it != fp32_node_args_by_op_type.cend() &&
|
||||
std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) {
|
||||
fp32_inputs.push_back({node, node_arg_slot});
|
||||
} else {
|
||||
auto it2 = fp32_node_args_by_node.find(node);
|
||||
if (it2 != fp32_node_args_by_node.cend() &&
|
||||
std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) {
|
||||
for (auto node_arg_slot : node_arg_slots) {
|
||||
if (it != fp32_node_args_by_op_type.cend() &&
|
||||
std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) {
|
||||
fp32_inputs.push_back({node, node_arg_slot});
|
||||
} else {
|
||||
mixed_precision_inputs.push_back({node, node_arg_slot});
|
||||
auto it2 = fp32_node_args_by_node.find(node);
|
||||
if (it2 != fp32_node_args_by_node.cend() &&
|
||||
std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) {
|
||||
fp32_inputs.push_back({node, node_arg_slot});
|
||||
} else {
|
||||
mixed_precision_inputs.push_back({node, node_arg_slot});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -475,7 +476,7 @@ static Status HandleFunctionBody(const Function& node_func, ONNX_NAMESPACE::Tens
|
|||
|
||||
ORT_RETURN_IF_ERROR(TransformConstants(graph, mixed_precision_type));
|
||||
|
||||
// End of stage 1. Update types of intermediate-values and return-values:
|
||||
// End of stage 1. Update types of intermediate-values and return-values:[
|
||||
Graph::ResolveOptions options;
|
||||
options.override_types = true;
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve(options));
|
||||
|
|
|
|||
Loading…
Reference in a new issue