diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index b683c61cfe..96c7bcbac3 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -84,29 +84,30 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph, std::vector>& fp32_inputs) { std::vector consumer_nodes = graph.GetMutableConsumerNodes(arg->Name()); for (Node* node : consumer_nodes) { - int node_arg_slot = -1; + std::vector node_arg_slots; for (int i = 0; i < static_cast(node->InputDefs().size()); i++) { if (node->InputDefs()[i] == arg) { - node_arg_slot = i; - break; + node_arg_slots.push_back(i); } } - if (node_arg_slot == -1) { + if (node_arg_slots.empty()) { continue; } auto it = fp32_node_args_by_op_type.find(node->OpType()); - if (it != fp32_node_args_by_op_type.cend() && - std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) { - fp32_inputs.push_back({node, node_arg_slot}); - } else { - auto it2 = fp32_node_args_by_node.find(node); - if (it2 != fp32_node_args_by_node.cend() && - std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) { + for (auto node_arg_slot : node_arg_slots) { + if (it != fp32_node_args_by_op_type.cend() && + std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) { fp32_inputs.push_back({node, node_arg_slot}); } else { - mixed_precision_inputs.push_back({node, node_arg_slot}); + auto it2 = fp32_node_args_by_node.find(node); + if (it2 != fp32_node_args_by_node.cend() && + std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) { + fp32_inputs.push_back({node, node_arg_slot}); + } else { + mixed_precision_inputs.push_back({node, node_arg_slot}); + } } } } @@ -475,7 +476,7 @@ static Status HandleFunctionBody(const Function& node_func, ONNX_NAMESPACE::Tens ORT_RETURN_IF_ERROR(TransformConstants(graph, mixed_precision_type)); - // End of stage 1. Update types of intermediate-values and return-values: + // End of stage 1. Update types of intermediate-values and return-values:[ Graph::ResolveOptions options; options.override_types = true; ORT_RETURN_IF_ERROR(graph.Resolve(options));