diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index 4b00f4c0a6..09cc5a9fca 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -31,11 +31,9 @@ namespace training { // The following is a list of ops, as well as functions, that will // continue to use 32-bit precision. Others will used reduced precision. -static const std::unordered_set FP32_Nodes = { - "SparseSoftmaxCrossEntropy", - "SparseSoftmaxCrossEntropyGrad", - "SoftmaxCrossEntropyLoss", - "SoftmaxCrossEntropyLossGrad"}; +// Loss Ops and loss grad Ops are now handled by LossSubgraph, so currently this set is empty. +// If a new FP32 Op is needed in the future, it can be added here without changing code elsewhere. +static const std::unordered_set FP32_Nodes = {}; bool IsFP32Node(const Node* node) { return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend(); @@ -49,15 +47,13 @@ static const std::unordered_map> stage1_fp32_node_ {"DropoutGrad", {2}}, }; +// Currently the list here is the same as stage1 above because FP32_Nodes is empty. +// It's possible that more FP32 nodes will be added, in which case this map will also be extended. 
static const std::unordered_map> stage2_fp32_node_args = { {"TrainableDropout", {1}}, {"TrainableDropoutGrad", {2}}, {"Dropout", {1}}, {"DropoutGrad", {2}}, - {"SparseSoftmaxCrossEntropy", {0, 2}}, - {"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}}, - {"SoftmaxCrossEntropyLoss", {0, 2}}, - {"SoftmaxCrossEntropyLossGrad", {0, 1, 3}}, }; bool IsFP32(const std::unordered_map>& map, std::string opname, int argnum) { @@ -70,10 +66,30 @@ bool IsFP32(const std::unordered_map>& map, std::s } } +static const std::string loss_scale_input = "loss_scale"; + +static const std::unordered_set loss_subgraph_entry_nodes = { + "SparseSoftmaxCrossEntropy", + "SoftmaxCrossEntropyLoss"}; + +static const std::unordered_set loss_subgraph_exit_nodes = { + "SparseSoftmaxCrossEntropyGrad", + "SoftmaxCrossEntropyLossGrad"}; + +static bool IsLossSubgraphEntryNode(const Node* node) { + return loss_subgraph_entry_nodes.find(node->OpType()) != loss_subgraph_entry_nodes.cend(); +} + +static bool IsLossSubgraphExitNode(const Node* node) { + return loss_subgraph_exit_nodes.find(node->OpType()) != loss_subgraph_exit_nodes.cend(); +} + // Separate the consumer nodes of `arg` into two groups: FP32 vs FP16 -// The argument `fp32_node_args` specifies the cases where the `arg` should be 32-bit float. +// The argument `fp32_node_args_by_op_type` specifies the cases where the `arg` should be 32-bit float using op type. +// The argument `fp32_node_args_by_node` specifies the cases where the `arg` should be 32-bit float using node pointer. 
static void GetConsumerNodeInputs(onnxruntime::Graph& graph, - const std::unordered_map>& fp32_node_args, + const std::unordered_map>& fp32_node_args_by_op_type, + const std::unordered_map>& fp32_node_args_by_node, const NodeArg* arg, std::vector>& fp16_inputs, std::vector>& fp32_inputs) { @@ -91,15 +107,17 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph, continue; } - auto it = fp32_node_args.find(node->OpType()); - if (it == fp32_node_args.cend()) { - fp16_inputs.push_back({node, node_arg_slot}); + auto it = fp32_node_args_by_op_type.find(node->OpType()); + if (it != fp32_node_args_by_op_type.cend() && + std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) { + fp32_inputs.push_back({node, node_arg_slot}); } else { - const auto index_it = std::find(it->second.cbegin(), it->second.cend(), node_arg_slot); - if (index_it == it->second.cend()) { - fp16_inputs.push_back({node, node_arg_slot}); - } else { + auto it2 = fp32_node_args_by_node.find(node); + if (it2 != fp32_node_args_by_node.cend() && + std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) { fp32_inputs.push_back({node, node_arg_slot}); + } else { + fp16_inputs.push_back({node, node_arg_slot}); } } } @@ -120,9 +138,11 @@ static void RewireCastedNodeArg(onnxruntime::Graph& graph, } // This function tries casting `arg` to `element_type`. -// The argument `fp32_node_args` specifies the cases where the `arg` should be 32-bit float. +// The argument `fp32_node_args_by_op_type` specifies the cases where the `arg` should be 32-bit float using op type. +// The argument `fp32_node_args_by_node` specifies the cases where the `arg` should be 32-bit float using node pointer. 
static Status CastNodeArg(onnxruntime::Graph& graph, - const std::unordered_map>& fp32_node_args, + const std::unordered_map>& fp32_node_args_by_op_type, + const std::unordered_map>& fp32_node_args_by_node, NodeArg* arg, ONNX_NAMESPACE::TensorProto_DataType elem_type) { if (arg == nullptr) { @@ -135,7 +155,7 @@ static Status CastNodeArg(onnxruntime::Graph& graph, // Get consumer nodes of the input `arg` std::vector> fp16_inputs; std::vector> fp32_inputs; - GetConsumerNodeInputs(graph, fp32_node_args, arg, fp16_inputs, fp32_inputs); + GetConsumerNodeInputs(graph, fp32_node_args_by_op_type, fp32_node_args_by_node, arg, fp16_inputs, fp32_inputs); if ((elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && fp16_inputs.empty()) || (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT && fp32_inputs.empty())) { return Status::OK(); @@ -216,24 +236,207 @@ static Status CastNodeArg(onnxruntime::Graph& graph, return Status::OK(); } -Status TransformConstants(Graph& graph) { +struct LossSubgraph { + // All nodes belonging to this subgraph. + std::unordered_set nodes_; + + // NodeArgs that are inputs of this subgraph from outside, which need to be converted to FP32. + std::unordered_set to_fp32_inputs_; + + // NodeArgs that are outputs of this subgraph to outside, which need to be converted to FP16. + std::unordered_set to_fp16_outputs_; + + // Nodes that take float input from outside of subgraph, the input indices are also saved. + // It's useful when calling CastNodeArg, so FP32 inputs will not need to be converted. + std::unordered_map> fp32_node_args_; + + LossSubgraph(Graph& graph) { + GraphViewer graph_viewer(graph); + const auto& order = graph_viewer.GetNodesInTopologicalOrder(); + + // Get the nodes related to loss scale. It's a Mul node and its grad nodes. + // We initialize loss subgraph only when there is loss scale as input. 
+ std::vector loss_scale_consumers = graph.GetMutableConsumerNodes(loss_scale_input); + if (loss_scale_consumers.size() == 0) { + return; + } + + nodes_.insert(loss_scale_consumers.begin(), loss_scale_consumers.end()); + for (Node* node : loss_scale_consumers) { + for (const NodeArg* output : node->OutputDefs()) { + std::vector level2_consumers = graph.GetMutableConsumerNodes(output->Name()); + nodes_.insert(level2_consumers.begin(), level2_consumers.end()); + } + } + + // The node number here depends on how to implement the gradient of Mul. + // Add this check here for safety at certain level. + ORT_ENFORCE(nodes_.size() == 3, + "The node number of the loss scale and it's grad subgraph is expected to be 3."); + + // Check if graph contains any loss Op from the white-list. + // If not, then above loss scale related nodes are all we need. + bool has_loss_subgraph_entry_node = false; + for (auto index : order) { + if (IsLossSubgraphEntryNode(graph.GetNode(index))) { + has_loss_subgraph_entry_node = true; + break; + } + } + + // If it contains one or more loss Ops from white-list, travel the graph again to get the whole loss subgraph. + if (has_loss_subgraph_entry_node) { + for (auto index : order) { + Node* node = graph.GetNode(index); + if (IsLossSubgraphEntryNode(node) || IsLossSubgraphExitNode(node)) { + nodes_.insert(node); + } else { + // For other nodes, if it consumes any output of any node from loss subgraph, it also belongs to loss subgraph. + bool part_of_loss_subgraph = false; + for (NodeArg* input : node->MutableInputDefs()) { + Node* producer_node = graph.GetMutableProducerNode(input->Name()); + if (producer_node != nullptr && + !IsLossSubgraphExitNode(producer_node) && + nodes_.find(producer_node) != nodes_.cend()) { + part_of_loss_subgraph = true; + break; + } + } + + if (part_of_loss_subgraph) { + nodes_.insert(node); + } + } + } + } + + // We now have all the nodes of the loss subgraph. Now get all float inputs from outside. 
+ for (Node* node : nodes_) { + int index = 0; + for (NodeArg* input : node->MutableInputDefs()) { + if (input->Name() != loss_scale_input && // loss_scale input will keep FP32, no need to handle here. + input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // If its producer is from outside, it's one of the inputs of this subgraph. + Node* producer_node = graph.GetMutableProducerNode(input->Name()); + if (producer_node == nullptr || nodes_.find(producer_node) == nodes_.cend()) { + to_fp32_inputs_.insert(input); + if (fp32_node_args_.find(node) == fp32_node_args_.cend()) { + fp32_node_args_[node] = {index}; + } else { + fp32_node_args_[node].push_back(index); + } + } + } + + index++; + } + + // Get all float outputs to outside of the subgraph. They will be converted to FP16. + for (NodeArg* output : node->MutableOutputDefs()) { + if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + !ContainsAllConsumers(graph, output->Name())) { + to_fp16_outputs_.insert(output); + } + } + } + } + + bool Contains(Node* node) { + return nodes_.find(node) != nodes_.cend(); + } + + // Check if this loss subgraph contains all the consumers of given Arg. + bool ContainsAllConsumers(Graph& graph, const std::string arg_name) { + std::vector consumer_nodes = graph.GetMutableConsumerNodes(arg_name); + for (Node* node : consumer_nodes) { + if (nodes_.find(node) == nodes_.cend()) { + return false; + } + } + + return true; + } + + // For those inputs and constants that are already handled, remove them from the to_fp32 list. 
+ void RemoveFromToFP32Inputs(const std::string& arg_name) { + auto it = to_fp32_inputs_.begin(); + while (it != to_fp32_inputs_.end()) { + if ((*it)->Name() == arg_name) { + it = to_fp32_inputs_.erase(it); + } else { + ++it; + } + } + } + + std::unordered_map>& GetFP32NodeArgs() { + return fp32_node_args_; + } + + // Once all inputs, constants, and function calls are handled, it's time to convert all + // inputs to FP32, and convert all outputs to FP16. + Status ProcessInputsAndOutputs(Graph& graph) { + for (auto* node_arg : to_fp32_inputs_) { + ORT_RETURN_IF_ERROR(CastNodeArg(graph, + stage1_fp32_node_args, + fp32_node_args_, + node_arg, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + } + + for (auto* node_arg : to_fp16_outputs_) { + ORT_RETURN_IF_ERROR(CastNodeArg(graph, + stage1_fp32_node_args, + fp32_node_args_, + node_arg, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + } + + return Status::OK(); + } +}; + +Status TransformConstants(Graph& graph, LossSubgraph* p_loss_subgraph = nullptr) { // This pass does not require topological sort order: okay to visit nodes in any order. // We identify nodeargs to be converted to FP16 first, and then convert them separately // to avoid modifying the graph while iterating through it. std::unordered_set toFP16; for (auto& node : graph.Nodes()) { + // Ignore any node in loss subgraph. + if (p_loss_subgraph != nullptr && p_loss_subgraph->Contains(&node)) { + continue; + } + const std::string& optype = node.OpType(); // TODO: Why do we need to handle "Cast" here? if ((optype == "Constant") || (optype == "Cast") || (optype == "ConstantOfShape")) { for (NodeArg* output : node.MutableOutputDefs()) { - if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) - toFP16.insert(output); + if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // If all consumers are from loss subgraph, don't convert it. 
+ if (p_loss_subgraph == nullptr || !p_loss_subgraph->ContainsAllConsumers(graph, output->Name())) { + toFP16.insert(output); + } + + if (p_loss_subgraph != nullptr) { + // If it's one of loss subgraph's input, remove it from the to-convert set since it's already handled. + p_loss_subgraph->RemoveFromToFP32Inputs(output->Name()); + } + } } } } + for (auto* tensor : toFP16) { - ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage1_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + ORT_RETURN_IF_ERROR( + CastNodeArg(graph, + stage1_fp32_node_args, + p_loss_subgraph != nullptr ? + p_loss_subgraph->GetFP32NodeArgs() : + std::unordered_map>(), + tensor, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); } + return Status::OK(); } @@ -241,7 +444,8 @@ Status TransformConstants(Graph& graph) { // as SparseSoftmaxCrossEntropy where FP32 precision is required. // Converts fp16 tensor --> Op --> fp16 tensor to // fp16 tensor --> Cast --> fp32 tensor --> Op --> fp32 tensor --> Cast --> fp16 tensor -Status TransformStage2(Graph& graph) { +Status TransformStage2(Graph& graph, + const std::unordered_map>& loss_subgraph_fp32_node_args = {}) { // This pass does not require topological sort order: okay to visit nodes in any order. 
std::unordered_set toFP16, toFP32; for (auto& node : graph.Nodes()) { @@ -260,13 +464,21 @@ Status TransformStage2(Graph& graph) { } } for (auto* tensor : toFP32) - ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage2_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + ORT_RETURN_IF_ERROR(CastNodeArg(graph, + stage2_fp32_node_args, + loss_subgraph_fp32_node_args, + tensor, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); for (auto* tensor : toFP16) - ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage2_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + ORT_RETURN_IF_ERROR(CastNodeArg(graph, + stage2_fp32_node_args, + loss_subgraph_fp32_node_args, + tensor, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); return Status::OK(); } -static Status HandleFunctionCalls(Graph& graph); +static Status HandleFunctionCalls(Graph& graph, LossSubgraph* p_loss_subgraph = nullptr); // TODO: Ideally, we should not need to transform a function-body here. // Ideally, for any full-precision function F, there should be a corresponding 16-bit precision @@ -304,7 +516,11 @@ static Status HandleFunctionBody(const Function& node_func) { // Introduce cast to full-precision if required: // TODO: fix const_cast; Graph doesn't provide us a method "GetMutableInputs". 
NodeArg* mutable_input = const_cast(input); - CastNodeArg(graph, stage1_fp32_node_args, mutable_input, ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + CastNodeArg(graph, + stage1_fp32_node_args, + std::unordered_map>(), + mutable_input, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT); } } @@ -329,24 +545,30 @@ static Status HandleFunctionBody(const Function& node_func) { return status; } -static Status HandleFunctionCalls(Graph& graph) { +static Status HandleFunctionCalls(Graph& graph, LossSubgraph* p_loss_subgraph) { GraphViewer graph_viewer(graph); const auto& order = graph_viewer.GetNodesInTopologicalOrder(); for (auto index : order) { Node* node = graph.GetNode(index); - if (!IsFP32Node(node)) { // Bodies of FP32 Functions are not transformed - const Function* node_func = node->GetFunctionBody(); - if (nullptr != node_func) { - ORT_RETURN_IF_ERROR(HandleFunctionBody(*node_func)); - } + // Bodies of FP32 Functions are not transformed. + if (IsFP32Node(node) || + (p_loss_subgraph != nullptr && p_loss_subgraph->Contains(node))) { + continue; + } + + const Function* node_func = node->GetFunctionBody(); + if (nullptr != node_func) { + ORT_RETURN_IF_ERROR(HandleFunctionBody(*node_func)); } } + return Status::OK(); } // Create FP16 NodeArg and update the consumers of arg with new FP16 NodeArg. 
static NodeArg* CreateFP16NodeArgAndUpdateConsumers(Graph& graph, - const std::unordered_map>& fp32_node_args, + const std::unordered_map>& fp32_node_args_by_op_type, + const std::unordered_map>& fp32_node_args_by_node, const NodeArg* arg) { ORT_ENFORCE(arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "data type is not float"); @@ -360,7 +582,7 @@ static NodeArg* CreateFP16NodeArgAndUpdateConsumers(Graph& graph, // Check consumer nodes std::vector> fp16_inputs; std::vector> fp32_inputs; - GetConsumerNodeInputs(graph, fp32_node_args, arg, fp16_inputs, fp32_inputs); + GetConsumerNodeInputs(graph, fp32_node_args_by_op_type, fp32_node_args_by_node, arg, fp16_inputs, fp32_inputs); if (fp16_inputs.empty()) { return nullptr; } @@ -376,6 +598,9 @@ Status TransformGraphForMixedPrecision(Graph& graph, const std::unordered_set& weights_to_train, bool use_fp16_initializer, std::unordered_map& fp32_weight_name_to_fp16_node_arg) { + // Stage 0: Initialize loss subgraph. + LossSubgraph loss_subgraph(graph); + // Stage 1: Convert whole graph including forward and backward to FP16 // Initialize function body for all function nodes // This is required to make sure after converting inputs\weights to FP16 @@ -385,10 +610,23 @@ Status TransformGraphForMixedPrecision(Graph& graph, } // Insert Cast node to convert inputs from FP32 to FP16 + // If all consumers are from the loss subgraph, don't convert it, and remove it from the to-FP32 loss subgraph inputs. for (const NodeArg* input : graph.GetInputs()) { - if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - ORT_RETURN_IF_ERROR( - CastNodeArg(graph, stage1_fp32_node_args, graph.GetNodeArg(input->Name()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + // Input loss_scale is always kept as FP32. 
+ if (input->Name() != loss_scale_input && + input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // If all consumers are from loss subgraph, no need to convert. + if (!loss_subgraph.ContainsAllConsumers(graph, input->Name())) { + ORT_RETURN_IF_ERROR( + CastNodeArg(graph, + stage1_fp32_node_args, + loss_subgraph.GetFP32NodeArgs(), + graph.GetNodeArg(input->Name()), + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + } + + // Remove it from the to-convert set since it's already handled. + loss_subgraph.RemoveFromToFP32Inputs(input->Name()); } } @@ -399,18 +637,31 @@ Status TransformGraphForMixedPrecision(Graph& graph, for (const auto& kv : initialized_tensors) { NodeArg* input = graph.GetNodeArg(kv.first); if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - if (use_fp16_initializer) { - NodeArg* fp16_weight_arg = CreateFP16NodeArgAndUpdateConsumers(graph, stage1_fp32_node_args, input); - if (fp16_weight_arg != nullptr) { - fp16_initializers.emplace_back(fp16_weight_arg->Name(), kv.second); - const auto it = weights_to_train.find(kv.first); - if (it != weights_to_train.cend()) { - fp32_weight_name_to_fp16_node_arg_result[kv.first] = fp16_weight_arg; + // If all consumers are from loss graph, don't convert it. 
+ if (!loss_subgraph.ContainsAllConsumers(graph, input->Name())) { + if (use_fp16_initializer) { + NodeArg* fp16_weight_arg = CreateFP16NodeArgAndUpdateConsumers(graph, + stage1_fp32_node_args, + loss_subgraph.GetFP32NodeArgs(), + input); + if (fp16_weight_arg != nullptr) { + fp16_initializers.emplace_back(fp16_weight_arg->Name(), kv.second); + const auto it = weights_to_train.find(kv.first); + if (it != weights_to_train.cend()) { + fp32_weight_name_to_fp16_node_arg_result[kv.first] = fp16_weight_arg; + } } + } else { + ORT_RETURN_IF_ERROR(CastNodeArg(graph, + stage1_fp32_node_args, + loss_subgraph.GetFP32NodeArgs(), + input, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); } - } else { - ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage1_fp32_node_args, input, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); } + + // Remove it from the to-convert set since it's already handled. + loss_subgraph.RemoveFromToFP32Inputs(input->Name()); } } @@ -426,7 +677,7 @@ Status TransformGraphForMixedPrecision(Graph& graph, for (auto& node : graph.Nodes()) { // For send and recv node, if the tensor being sent or received is FP32, update its // attribute and change it to FP16. - if (!node.OpType().compare("Send") || !node.OpType().compare("Recv")) { + if ((!node.OpType().compare("Send") || !node.OpType().compare("Recv")) && !loss_subgraph.Contains(&node)) { auto& attributes = node.GetMutableAttributes(); auto* element_type = &(attributes.find("element_types")->second); int ints_size = element_type->ints_size(); @@ -441,10 +692,13 @@ Status TransformGraphForMixedPrecision(Graph& graph, } // Handle implicit data type casting nodes such as Cast, ConstantOfShape - ORT_RETURN_IF_ERROR(TransformConstants(graph)); + ORT_RETURN_IF_ERROR(TransformConstants(graph, &loss_subgraph)); // Handle function body - ORT_RETURN_IF_ERROR(HandleFunctionCalls(graph)); + ORT_RETURN_IF_ERROR(HandleFunctionCalls(graph, &loss_subgraph)); + + // Handle loss graph inputs and outputs. 
+ ORT_RETURN_IF_ERROR(loss_subgraph.ProcessInputsAndOutputs(graph)); // At this point, the model has been transformed to a valid FP16 model. @@ -454,7 +708,7 @@ Status TransformGraphForMixedPrecision(Graph& graph, ORT_RETURN_IF_ERROR(graph.Resolve(options)); - TransformStage2(graph); + TransformStage2(graph, loss_subgraph.GetFP32NodeArgs()); ORT_RETURN_IF_ERROR(graph.Resolve(options)); diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index dd1e19980b..6b650eed47 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -794,7 +794,15 @@ const DataTransferManager& TrainingSession::GetDataTransferManager() const { bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) const { auto output_producer_node = model_->MainGraph().GetProducerNode(output_name); ORT_ENFORCE(output_producer_node != nullptr, "Output: " + output_name + " is not produced by any node."); - return IsFP32Node(output_producer_node); + + for (auto output : output_producer_node->OutputDefs()) { + if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type() + && output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return true; + } + } + + return false; } common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {