Fix the mix precision handle for square case (#8333)

* handle unsqueeze change in opset13

* fix the node arguments index check for square case (x * x)

* Revert "fix the node arguments index check for square case (x * x)"

This reverts commit c66344f0a82c35d8c24d31f2264cf7e9b235ce22.

* handle the square case (x * x) for node argument search

Co-authored-by: Cheng Tang <chenta@microsoft.com>
This commit is contained in:
Tang, Cheng 2021-07-09 09:24:19 -07:00 committed by GitHub
parent 187743726b
commit 598454bb5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -84,29 +84,30 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph,
std::vector<std::pair<Node*, int>>& fp32_inputs) {
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(arg->Name());
for (Node* node : consumer_nodes) {
int node_arg_slot = -1;
std::vector<int> node_arg_slots;
for (int i = 0; i < static_cast<int>(node->InputDefs().size()); i++) {
if (node->InputDefs()[i] == arg) {
node_arg_slot = i;
break;
node_arg_slots.push_back(i);
}
}
if (node_arg_slot == -1) {
if (node_arg_slots.empty()) {
continue;
}
auto it = fp32_node_args_by_op_type.find(node->OpType());
if (it != fp32_node_args_by_op_type.cend() &&
std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) {
fp32_inputs.push_back({node, node_arg_slot});
} else {
auto it2 = fp32_node_args_by_node.find(node);
if (it2 != fp32_node_args_by_node.cend() &&
std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) {
for (auto node_arg_slot : node_arg_slots) {
if (it != fp32_node_args_by_op_type.cend() &&
std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) {
fp32_inputs.push_back({node, node_arg_slot});
} else {
mixed_precision_inputs.push_back({node, node_arg_slot});
auto it2 = fp32_node_args_by_node.find(node);
if (it2 != fp32_node_args_by_node.cend() &&
std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) {
fp32_inputs.push_back({node, node_arg_slot});
} else {
mixed_precision_inputs.push_back({node, node_arg_slot});
}
}
}
}
@ -475,7 +476,7 @@ static Status HandleFunctionBody(const Function& node_func, ONNX_NAMESPACE::Tens
ORT_RETURN_IF_ERROR(TransformConstants(graph, mixed_precision_type));
// End of stage 1. Update types of intermediate-values and return-values:
// End of stage 1. Update types of intermediate-values and return-values:[
Graph::ResolveOptions options;
options.override_types = true;
ORT_RETURN_IF_ERROR(graph.Resolve(options));