From 598454bb5fceff02800997262a0599522fb9c2d8 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Fri, 9 Jul 2021 09:24:19 -0700 Subject: [PATCH] Fix the mix precision handle for square case (#8333) * handle unsqueeze change in opset13 * fix the node arguments index check for square case (x * x) * Revert "fix the node arguments index check for square case (x * x)" This reverts commit c66344f0a82c35d8c24d31f2264cf7e9b235ce22. * handle the square case (x * x) for node argument search Co-authored-by: Cheng Tang --- .../core/graph/mixed_precision_transformer.cc | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index b683c61cfe..96c7bcbac3 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -84,29 +84,30 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph, std::vector>& fp32_inputs) { std::vector consumer_nodes = graph.GetMutableConsumerNodes(arg->Name()); for (Node* node : consumer_nodes) { - int node_arg_slot = -1; + std::vector node_arg_slots; for (int i = 0; i < static_cast(node->InputDefs().size()); i++) { if (node->InputDefs()[i] == arg) { - node_arg_slot = i; - break; + node_arg_slots.push_back(i); } } - if (node_arg_slot == -1) { + if (node_arg_slots.empty()) { continue; } auto it = fp32_node_args_by_op_type.find(node->OpType()); - if (it != fp32_node_args_by_op_type.cend() && - std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) { - fp32_inputs.push_back({node, node_arg_slot}); - } else { - auto it2 = fp32_node_args_by_node.find(node); - if (it2 != fp32_node_args_by_node.cend() && - std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) { + for (auto node_arg_slot : node_arg_slots) { + if (it != fp32_node_args_by_op_type.cend() && + std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) { fp32_inputs.push_back({node, node_arg_slot}); } else { - mixed_precision_inputs.push_back({node, node_arg_slot}); + auto it2 = fp32_node_args_by_node.find(node); + if (it2 != fp32_node_args_by_node.cend() && + std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) { + fp32_inputs.push_back({node, node_arg_slot}); + } else { + mixed_precision_inputs.push_back({node, node_arg_slot}); + } } } } @@ -475,7 +476,7 @@ static Status HandleFunctionBody(const Function& node_func, ONNX_NAMESPACE::Tens ORT_RETURN_IF_ERROR(TransformConstants(graph, mixed_precision_type)); - // End of stage 1. Update types of intermediate-values and return-values: + // End of stage 1. Update types of intermediate-values and return-values:[ Graph::ResolveOptions options; options.override_types = true; ORT_RETURN_IF_ERROR(graph.Resolve(options));