Keep loss_scale and Whole Loss Subgraph in FP32 during Mixed Precision Training (#4268)

* Keep loss subgraph as FP32 when mixed-p training.

* Fix case where there is no white-list loss op.

* Get nodes from loss_scale instead of whitelist.

* rename const variables.

Co-authored-by: Vincent Wang <weicwang@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Vincent Wang 2020-07-03 06:54:56 +08:00 committed by GitHub
parent 7a05b3ca87
commit 28e4c0edf5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 317 additions and 55 deletions

View file

@ -31,11 +31,9 @@ namespace training {
// The following is a list of ops, as well as functions, that will
// continue to use 32-bit precision. Others will used reduced precision.
static const std::unordered_set<std::string> FP32_Nodes = {
"SparseSoftmaxCrossEntropy",
"SparseSoftmaxCrossEntropyGrad",
"SoftmaxCrossEntropyLoss",
"SoftmaxCrossEntropyLossGrad"};
// Loss ops and loss-grad ops are now handled by LossSubgraph, so this set is currently empty.
// If a new FP32-only op is needed in the future, it can be added here without changing code elsewhere.
static const std::unordered_set<std::string> FP32_Nodes = {};
bool IsFP32Node(const Node* node) {
return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend();
@ -49,15 +47,13 @@ static const std::unordered_map<std::string, std::vector<int>> stage1_fp32_node_
{"DropoutGrad", {2}},
};
// Currently this list is the same as the stage1 list above because FP32_Nodes is empty.
// If more FP32 nodes are added in the future, this map will need to be extended as well.
static const std::unordered_map<std::string, std::vector<int>> stage2_fp32_node_args = {
{"TrainableDropout", {1}},
{"TrainableDropoutGrad", {2}},
{"Dropout", {1}},
{"DropoutGrad", {2}},
{"SparseSoftmaxCrossEntropy", {0, 2}},
{"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}},
{"SoftmaxCrossEntropyLoss", {0, 2}},
{"SoftmaxCrossEntropyLossGrad", {0, 1, 3}},
};
bool IsFP32(const std::unordered_map<std::string, std::vector<int>>& map, std::string opname, int argnum) {
@ -70,10 +66,30 @@ bool IsFP32(const std::unordered_map<std::string, std::vector<int>>& map, std::s
}
}
// Name of the graph input that carries the loss scale.
static const std::string loss_scale_input{"loss_scale"};

// Op types treated as entry points of the loss subgraph.
static const std::unordered_set<std::string> loss_subgraph_entry_nodes = {
    "SparseSoftmaxCrossEntropy",
    "SoftmaxCrossEntropyLoss",
};

// Op types treated as exit points of the loss subgraph.
static const std::unordered_set<std::string> loss_subgraph_exit_nodes = {
    "SparseSoftmaxCrossEntropyGrad",
    "SoftmaxCrossEntropyLossGrad",
};
// Returns true when the op type of `node` is listed in loss_subgraph_entry_nodes.
static bool IsLossSubgraphEntryNode(const Node* node) {
  const std::string& op_type = node->OpType();
  return loss_subgraph_entry_nodes.count(op_type) != 0;
}
// Returns true when the op type of `node` is listed in loss_subgraph_exit_nodes.
static bool IsLossSubgraphExitNode(const Node* node) {
  const std::string& op_type = node->OpType();
  return loss_subgraph_exit_nodes.count(op_type) != 0;
}
// Separate the consumer nodes of `arg` into two groups: FP32 vs FP16
// The argument `fp32_node_args` specifies the cases where the `arg` should be 32-bit float.
// The argument `fp32_node_args_by_op_type` specifies the cases where the `arg` should be 32-bit float using op type.
// The argument `fp32_node_args_by_node` specifies the cases where the `arg` should be 32-bit float using node pointer.
static void GetConsumerNodeInputs(onnxruntime::Graph& graph,
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args,
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args_by_op_type,
const std::unordered_map<Node*, std::vector<int>>& fp32_node_args_by_node,
const NodeArg* arg,
std::vector<std::pair<Node*, int>>& fp16_inputs,
std::vector<std::pair<Node*, int>>& fp32_inputs) {
@ -91,15 +107,17 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph,
continue;
}
auto it = fp32_node_args.find(node->OpType());
if (it == fp32_node_args.cend()) {
fp16_inputs.push_back({node, node_arg_slot});
auto it = fp32_node_args_by_op_type.find(node->OpType());
if (it != fp32_node_args_by_op_type.cend() &&
std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) {
fp32_inputs.push_back({node, node_arg_slot});
} else {
const auto index_it = std::find(it->second.cbegin(), it->second.cend(), node_arg_slot);
if (index_it == it->second.cend()) {
fp16_inputs.push_back({node, node_arg_slot});
} else {
auto it2 = fp32_node_args_by_node.find(node);
if (it2 != fp32_node_args_by_node.cend() &&
std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) {
fp32_inputs.push_back({node, node_arg_slot});
} else {
fp16_inputs.push_back({node, node_arg_slot});
}
}
}
@ -120,9 +138,11 @@ static void RewireCastedNodeArg(onnxruntime::Graph& graph,
}
// This function tries casting `arg` to `element_type`.
// The argument `fp32_node_args` specifies the cases where the `arg` should be 32-bit float.
// The argument `fp32_node_args_by_op_type` specifies the cases where the `arg` should be 32-bit float using op type.
// The argument `fp32_node_args_by_node` specifies the cases where the `arg` should be 32-bit float using node pointer.
static Status CastNodeArg(onnxruntime::Graph& graph,
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args,
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args_by_op_type,
const std::unordered_map<Node*, std::vector<int>>& fp32_node_args_by_node,
NodeArg* arg,
ONNX_NAMESPACE::TensorProto_DataType elem_type) {
if (arg == nullptr) {
@ -135,7 +155,7 @@ static Status CastNodeArg(onnxruntime::Graph& graph,
// Get consumer nodes of the input `arg`
std::vector<std::pair<Node*, int>> fp16_inputs;
std::vector<std::pair<Node*, int>> fp32_inputs;
GetConsumerNodeInputs(graph, fp32_node_args, arg, fp16_inputs, fp32_inputs);
GetConsumerNodeInputs(graph, fp32_node_args_by_op_type, fp32_node_args_by_node, arg, fp16_inputs, fp32_inputs);
if ((elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && fp16_inputs.empty()) ||
(elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT && fp32_inputs.empty())) {
return Status::OK();
@ -216,24 +236,207 @@ static Status CastNodeArg(onnxruntime::Graph& graph,
return Status::OK();
}
Status TransformConstants(Graph& graph) {
struct LossSubgraph {
// All nodes belong to this subgraph.
std::unordered_set<Node*> nodes_;
// NodeArgs that are inputs of this subgraph from outside, which need to be converted to FP32.
std::unordered_set<NodeArg*> to_fp32_inputs_;
// NodeArgs that are outputs of this subgraph to outside, which need to be converted to FP16.
std::unordered_set<NodeArg*> to_fp16_outputs_;
// Nodes that take float input from outside of subgraph, the input indices are also saved.
// It's useful when calling CastNodeArg, so FP32 inputs will no need to be converted.
std::unordered_map<Node*, std::vector<int>> fp32_node_args_;
LossSubgraph(Graph& graph) {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
// Get the nodes related to loss scale. It's a Mul node and it's grad nodes.
// We initialize loss subgraph only when there is loss scale as input.
std::vector<Node*> loss_scale_consumers = graph.GetMutableConsumerNodes(loss_scale_input);
if (loss_scale_consumers.size() == 0) {
return;
}
nodes_.insert(loss_scale_consumers.begin(), loss_scale_consumers.end());
for (Node* node : loss_scale_consumers) {
for (const NodeArg* output : node->OutputDefs()) {
std::vector<Node*> level2_consumers = graph.GetMutableConsumerNodes(output->Name());
nodes_.insert(level2_consumers.begin(), level2_consumers.end());
}
}
// The node number here depends on how to implement the gradient of Mul.
// Add this check here for safety at certain level.
ORT_ENFORCE(nodes_.size() == 3,
"The node number of the loss scale and it's grad subgraph is expected to be 3.");
// Check if graph contains any loss Op from the white-list.
// If not, then above loss scale related nodes are all we need.
bool has_loss_subgraph_entry_node = false;
for (auto index : order) {
if (IsLossSubgraphEntryNode(graph.GetNode(index))) {
has_loss_subgraph_entry_node = true;
break;
}
}
// If it contains one or more loss Ops from white-list, travel the graph again to get the whole loss subgraph.
if (has_loss_subgraph_entry_node) {
for (auto index : order) {
Node* node = graph.GetNode(index);
if (IsLossSubgraphEntryNode(node) || IsLossSubgraphExitNode(node)) {
nodes_.insert(node);
} else {
// For other nodes, if it consumes any output of any node from loss subgraph, it also belongs to loss subgraph.
bool part_of_loss_subgraph = false;
for (NodeArg* input : node->MutableInputDefs()) {
Node* producer_node = graph.GetMutableProducerNode(input->Name());
if (producer_node != nullptr &&
!IsLossSubgraphExitNode(producer_node) &&
nodes_.find(producer_node) != nodes_.cend()) {
part_of_loss_subgraph = true;
break;
}
}
if (part_of_loss_subgraph) {
nodes_.insert(node);
}
}
}
}
// We now have all the nodes of the loss subgraph. Now get all float inputs from outside.
for (Node* node : nodes_) {
int index = 0;
for (NodeArg* input : node->MutableInputDefs()) {
if (input->Name() != loss_scale_input && // loss_scale input will keep FP32, no need to handle here.
input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
// If its producer is from outside, it's one of the inputs of this subgraph.
Node* producer_node = graph.GetMutableProducerNode(input->Name());
if (producer_node == nullptr || nodes_.find(producer_node) == nodes_.cend()) {
to_fp32_inputs_.insert(input);
if (fp32_node_args_.find(node) == fp32_node_args_.cend()) {
fp32_node_args_[node] = {index};
} else {
fp32_node_args_[node].push_back(index);
}
}
}
index++;
}
// Get all float outputs to outside of the subgraph. They will be converted to FP16.
for (NodeArg* output : node->MutableOutputDefs()) {
if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
!ContainsAllConsumers(graph, output->Name())) {
to_fp16_outputs_.insert(output);
}
}
}
}
bool Contains(Node* node) {
return nodes_.find(node) != nodes_.cend();
}
// Check if this loss subgraph contains all the consumers of given Arg.
bool ContainsAllConsumers(Graph& graph, const std::string arg_name) {
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(arg_name);
for (Node* node : consumer_nodes) {
if (nodes_.find(node) == nodes_.cend()) {
return false;
}
}
return true;
}
// For those inputs and constants that are already handled, remove them from the to_fp32 list.
void RemoveFromToFP32Inputs(const std::string& arg_name) {
auto it = to_fp32_inputs_.begin();
while (it != to_fp32_inputs_.end()) {
if ((*it)->Name() == arg_name) {
it = to_fp32_inputs_.erase(it);
} else {
++it;
}
}
}
std::unordered_map<Node*, std::vector<int>>& GetFP32NodeArgs() {
return fp32_node_args_;
}
// Once all inputs, constants, and function calls are handled, it's time to convert all
// inputs to FP32, and convert all outputs to FP16.
Status ProcessInputsAndOutputs(Graph& graph) {
for (auto* node_arg : to_fp32_inputs_) {
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
stage1_fp32_node_args,
fp32_node_args_,
node_arg,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
}
for (auto* node_arg : to_fp16_outputs_) {
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
stage1_fp32_node_args,
fp32_node_args_,
node_arg,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
}
return Status::OK();
}
};
Status TransformConstants(Graph& graph, LossSubgraph* p_loss_subgraph = nullptr) {
// This pass does not require topological sort order: okay to visit nodes in any order.
// We identify nodeargs to be converted to FP16 first, and then convert them separately
// to avoid modifying the graph while iterating through it.
std::unordered_set<NodeArg*> toFP16;
for (auto& node : graph.Nodes()) {
// Ignore any node in loss subgraph.
if (p_loss_subgraph != nullptr && p_loss_subgraph->Contains(&node)) {
continue;
}
const std::string& optype = node.OpType();
// TODO: Why do we need to handle "Cast" here?
if ((optype == "Constant") || (optype == "Cast") || (optype == "ConstantOfShape")) {
for (NodeArg* output : node.MutableOutputDefs()) {
if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
toFP16.insert(output);
if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
// If all consumers are from loss subgraph, don't convert it.
if (p_loss_subgraph == nullptr || !p_loss_subgraph->ContainsAllConsumers(graph, output->Name())) {
toFP16.insert(output);
}
if (p_loss_subgraph != nullptr) {
// If it's one of loss subgraph's input, remove it from the to-convert set since it's already handled.
p_loss_subgraph->RemoveFromToFP32Inputs(output->Name());
}
}
}
}
}
for (auto* tensor : toFP16) {
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage1_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
ORT_RETURN_IF_ERROR(
CastNodeArg(graph,
stage1_fp32_node_args,
p_loss_subgraph != nullptr ?
p_loss_subgraph->GetFP32NodeArgs() :
std::unordered_map<Node*, std::vector<int>>(),
tensor,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
}
return Status::OK();
}
@ -241,7 +444,8 @@ Status TransformConstants(Graph& graph) {
// as SparseSoftmaxCrossEntropy where FP32 precision is required.
// Converts fp16 tensor --> Op --> fp16 tensor to
// fp16 tensor --> Cast --> fp32 tensor --> Op --> fp32 tensor --> Cast --> fp16 tensor
Status TransformStage2(Graph& graph) {
Status TransformStage2(Graph& graph,
const std::unordered_map<Node*, std::vector<int>>& loss_subgraph_fp32_node_args = {}) {
// This pass does not require topological sort order: okay to visit nodes in any order.
std::unordered_set<NodeArg *> toFP16, toFP32;
for (auto& node : graph.Nodes()) {
@ -260,13 +464,21 @@ Status TransformStage2(Graph& graph) {
}
}
for (auto* tensor : toFP32)
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage2_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
stage2_fp32_node_args,
loss_subgraph_fp32_node_args,
tensor,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
for (auto* tensor : toFP16)
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage2_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
stage2_fp32_node_args,
loss_subgraph_fp32_node_args,
tensor,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
return Status::OK();
}
static Status HandleFunctionCalls(Graph& graph);
static Status HandleFunctionCalls(Graph& graph, LossSubgraph* p_loss_subgraph = nullptr);
// TODO: Ideally, we should not need to transform a function-body here.
// Ideally, for any full-precision function F, there should be a corresponding 16-bit precision
@ -304,7 +516,11 @@ static Status HandleFunctionBody(const Function& node_func) {
// Introduce cast to full-precision if required:
// TODO: fix const_cast; Graph doesn't provide us a method "GetMutableInputs".
NodeArg* mutable_input = const_cast<NodeArg*>(input);
CastNodeArg(graph, stage1_fp32_node_args, mutable_input, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
CastNodeArg(graph,
stage1_fp32_node_args,
std::unordered_map<Node*, std::vector<int>>(),
mutable_input,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
}
}
@ -329,24 +545,30 @@ static Status HandleFunctionBody(const Function& node_func) {
return status;
}
static Status HandleFunctionCalls(Graph& graph) {
static Status HandleFunctionCalls(Graph& graph, LossSubgraph* p_loss_subgraph) {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
for (auto index : order) {
Node* node = graph.GetNode(index);
if (!IsFP32Node(node)) { // Bodies of FP32 Functions are not transformed
const Function* node_func = node->GetFunctionBody();
if (nullptr != node_func) {
ORT_RETURN_IF_ERROR(HandleFunctionBody(*node_func));
}
// Bodies of FP32 Functions are not transformed.
if (IsFP32Node(node) ||
(p_loss_subgraph != nullptr && p_loss_subgraph->Contains(node))) {
continue;
}
const Function* node_func = node->GetFunctionBody();
if (nullptr != node_func) {
ORT_RETURN_IF_ERROR(HandleFunctionBody(*node_func));
}
}
return Status::OK();
}
// Create FP16 NodeArg and update the consumers of arg with new FP16 NodeArg.
static NodeArg* CreateFP16NodeArgAndUpdateConsumers(Graph& graph,
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args,
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args_by_op_type,
const std::unordered_map<Node*, std::vector<int>>& fp32_node_args_by_node,
const NodeArg* arg) {
ORT_ENFORCE(arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
"data type is not float");
@ -360,7 +582,7 @@ static NodeArg* CreateFP16NodeArgAndUpdateConsumers(Graph& graph,
// Check consumer nodes
std::vector<std::pair<Node*, int>> fp16_inputs;
std::vector<std::pair<Node*, int>> fp32_inputs;
GetConsumerNodeInputs(graph, fp32_node_args, arg, fp16_inputs, fp32_inputs);
GetConsumerNodeInputs(graph, fp32_node_args_by_op_type, fp32_node_args_by_node, arg, fp16_inputs, fp32_inputs);
if (fp16_inputs.empty()) {
return nullptr;
}
@ -376,6 +598,9 @@ Status TransformGraphForMixedPrecision(Graph& graph,
const std::unordered_set<std::string>& weights_to_train,
bool use_fp16_initializer,
std::unordered_map<std::string, NodeArg*>& fp32_weight_name_to_fp16_node_arg) {
// Stage 0: Initialize loss subgraph.
LossSubgraph loss_subgraph(graph);
// Stage 1: Convert whole graph including forward and backward to FP16
// Initialize function body for all function nodes
// This is required to make sure after converting inputs\weights to FP16
@ -385,10 +610,23 @@ Status TransformGraphForMixedPrecision(Graph& graph,
}
// Insert Cast node to convert inputs from FP32 to FP16
// If all consumers are in the loss subgraph, don't convert it. In either case, remove it from the loss subgraph's to-FP32 inputs.
for (const NodeArg* input : graph.GetInputs()) {
if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
ORT_RETURN_IF_ERROR(
CastNodeArg(graph, stage1_fp32_node_args, graph.GetNodeArg(input->Name()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
// Input loss_scale will always keep as FP32.
if (input->Name() != loss_scale_input &&
input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
// If all consumers are from loss subgraph, no need to convert.
if (!loss_subgraph.ContainsAllConsumers(graph, input->Name())) {
ORT_RETURN_IF_ERROR(
CastNodeArg(graph,
stage1_fp32_node_args,
loss_subgraph.GetFP32NodeArgs(),
graph.GetNodeArg(input->Name()),
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
}
// Remove it from the to-convert set since it's already handled.
loss_subgraph.RemoveFromToFP32Inputs(input->Name());
}
}
@ -399,18 +637,31 @@ Status TransformGraphForMixedPrecision(Graph& graph,
for (const auto& kv : initialized_tensors) {
NodeArg* input = graph.GetNodeArg(kv.first);
if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (use_fp16_initializer) {
NodeArg* fp16_weight_arg = CreateFP16NodeArgAndUpdateConsumers(graph, stage1_fp32_node_args, input);
if (fp16_weight_arg != nullptr) {
fp16_initializers.emplace_back(fp16_weight_arg->Name(), kv.second);
const auto it = weights_to_train.find(kv.first);
if (it != weights_to_train.cend()) {
fp32_weight_name_to_fp16_node_arg_result[kv.first] = fp16_weight_arg;
// If all consumers are from loss graph, don't convert it.
if (!loss_subgraph.ContainsAllConsumers(graph, input->Name())) {
if (use_fp16_initializer) {
NodeArg* fp16_weight_arg = CreateFP16NodeArgAndUpdateConsumers(graph,
stage1_fp32_node_args,
loss_subgraph.GetFP32NodeArgs(),
input);
if (fp16_weight_arg != nullptr) {
fp16_initializers.emplace_back(fp16_weight_arg->Name(), kv.second);
const auto it = weights_to_train.find(kv.first);
if (it != weights_to_train.cend()) {
fp32_weight_name_to_fp16_node_arg_result[kv.first] = fp16_weight_arg;
}
}
} else {
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
stage1_fp32_node_args,
loss_subgraph.GetFP32NodeArgs(),
input,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
}
} else {
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage1_fp32_node_args, input, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
}
// Remove it from the to-convert set since it's already handled.
loss_subgraph.RemoveFromToFP32Inputs(input->Name());
}
}
@ -426,7 +677,7 @@ Status TransformGraphForMixedPrecision(Graph& graph,
for (auto& node : graph.Nodes()) {
// For send and recv node, if the tensor being sent or received is FP32, update its
// attribute and change it to FP16.
if (!node.OpType().compare("Send") || !node.OpType().compare("Recv")) {
if ((!node.OpType().compare("Send") || !node.OpType().compare("Recv")) && !loss_subgraph.Contains(&node)) {
auto& attributes = node.GetMutableAttributes();
auto* element_type = &(attributes.find("element_types")->second);
int ints_size = element_type->ints_size();
@ -441,10 +692,13 @@ Status TransformGraphForMixedPrecision(Graph& graph,
}
// Handle implicit data type casting nodes such as Cast, ConstantOfShape
ORT_RETURN_IF_ERROR(TransformConstants(graph));
ORT_RETURN_IF_ERROR(TransformConstants(graph, &loss_subgraph));
// Handle function body
ORT_RETURN_IF_ERROR(HandleFunctionCalls(graph));
ORT_RETURN_IF_ERROR(HandleFunctionCalls(graph, &loss_subgraph));
// Handle loss graph inputs and outputs.
ORT_RETURN_IF_ERROR(loss_subgraph.ProcessInputsAndOutputs(graph));
// At this point, the model has been transformed to a valid FP16 model.
@ -454,7 +708,7 @@ Status TransformGraphForMixedPrecision(Graph& graph,
ORT_RETURN_IF_ERROR(graph.Resolve(options));
TransformStage2(graph);
TransformStage2(graph, loss_subgraph.GetFP32NodeArgs());
ORT_RETURN_IF_ERROR(graph.Resolve(options));

View file

@ -794,7 +794,15 @@ const DataTransferManager& TrainingSession::GetDataTransferManager() const {
bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) const {
auto output_producer_node = model_->MainGraph().GetProducerNode(output_name);
ORT_ENFORCE(output_producer_node != nullptr, "Output: " + output_name + " is not produced by any node.");
return IsFP32Node(output_producer_node);
for (auto output : output_producer_node->OutputDefs()) {
if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type()
&& output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return true;
}
}
return false;
}
common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {