mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Keep loss_scale and Whole Loss Subgraph in FP32 during Mixed Precision Training (#4268)
* Keep loss subgraph as FP32 when mixed-p training. * Fix case where there is no white-list loss op. * Get nodes from loss_scale instead of whitelist. * rename const variables. Co-authored-by: Vincent Wang <weicwang@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
7a05b3ca87
commit
28e4c0edf5
2 changed files with 317 additions and 55 deletions
|
|
@ -31,11 +31,9 @@ namespace training {
|
|||
|
||||
// The following is a list of ops, as well as functions, that will
|
||||
// continue to use 32-bit precision. Others will used reduced precision.
|
||||
static const std::unordered_set<std::string> FP32_Nodes = {
|
||||
"SparseSoftmaxCrossEntropy",
|
||||
"SparseSoftmaxCrossEntropyGrad",
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
"SoftmaxCrossEntropyLossGrad"};
|
||||
// Loss Ops and loss grad Ops are now handled by LossSubgraph, so currently this set is empty.
|
||||
// If in the future there is new FP32 Op, we can add it here without changing code on other place.
|
||||
static const std::unordered_set<std::string> FP32_Nodes = {};
|
||||
|
||||
bool IsFP32Node(const Node* node) {
|
||||
return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend();
|
||||
|
|
@ -49,15 +47,13 @@ static const std::unordered_map<std::string, std::vector<int>> stage1_fp32_node_
|
|||
{"DropoutGrad", {2}},
|
||||
};
|
||||
|
||||
// Currently the list here is same as stage1 above due to empty FP32_Nodes.
|
||||
// It's possibile we will have more FP32 nodes added, this map will also be extended.
|
||||
static const std::unordered_map<std::string, std::vector<int>> stage2_fp32_node_args = {
|
||||
{"TrainableDropout", {1}},
|
||||
{"TrainableDropoutGrad", {2}},
|
||||
{"Dropout", {1}},
|
||||
{"DropoutGrad", {2}},
|
||||
{"SparseSoftmaxCrossEntropy", {0, 2}},
|
||||
{"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}},
|
||||
{"SoftmaxCrossEntropyLoss", {0, 2}},
|
||||
{"SoftmaxCrossEntropyLossGrad", {0, 1, 3}},
|
||||
};
|
||||
|
||||
bool IsFP32(const std::unordered_map<std::string, std::vector<int>>& map, std::string opname, int argnum) {
|
||||
|
|
@ -70,10 +66,30 @@ bool IsFP32(const std::unordered_map<std::string, std::vector<int>>& map, std::s
|
|||
}
|
||||
}
|
||||
|
||||
static const std::string loss_scale_input = "loss_scale";
|
||||
|
||||
static const std::unordered_set<std::string> loss_subgraph_entry_nodes = {
|
||||
"SparseSoftmaxCrossEntropy",
|
||||
"SoftmaxCrossEntropyLoss"};
|
||||
|
||||
static const std::unordered_set<std::string> loss_subgraph_exit_nodes = {
|
||||
"SparseSoftmaxCrossEntropyGrad",
|
||||
"SoftmaxCrossEntropyLossGrad"};
|
||||
|
||||
static bool IsLossSubgraphEntryNode(const Node* node) {
|
||||
return loss_subgraph_entry_nodes.find(node->OpType()) != loss_subgraph_entry_nodes.cend();
|
||||
}
|
||||
|
||||
static bool IsLossSubgraphExitNode(const Node* node) {
|
||||
return loss_subgraph_exit_nodes.find(node->OpType()) != loss_subgraph_exit_nodes.cend();
|
||||
}
|
||||
|
||||
// Separate the consumer nodes of `arg` into two groups: FP32 vs FP16
|
||||
// The argument `fp32_node_args` specifies the cases where the `arg` should be 32-bit float.
|
||||
// The argument `fp32_node_args_by_op_type` specifies the cases where the `arg` should be 32-bit float using op type.
|
||||
// The argument `fp32_node_args_by_node` specifies the cases where the `arg` should be 32-bit float using node pointer.
|
||||
static void GetConsumerNodeInputs(onnxruntime::Graph& graph,
|
||||
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args,
|
||||
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args_by_op_type,
|
||||
const std::unordered_map<Node*, std::vector<int>>& fp32_node_args_by_node,
|
||||
const NodeArg* arg,
|
||||
std::vector<std::pair<Node*, int>>& fp16_inputs,
|
||||
std::vector<std::pair<Node*, int>>& fp32_inputs) {
|
||||
|
|
@ -91,15 +107,17 @@ static void GetConsumerNodeInputs(onnxruntime::Graph& graph,
|
|||
continue;
|
||||
}
|
||||
|
||||
auto it = fp32_node_args.find(node->OpType());
|
||||
if (it == fp32_node_args.cend()) {
|
||||
fp16_inputs.push_back({node, node_arg_slot});
|
||||
auto it = fp32_node_args_by_op_type.find(node->OpType());
|
||||
if (it != fp32_node_args_by_op_type.cend() &&
|
||||
std::find(it->second.cbegin(), it->second.cend(), node_arg_slot) != it->second.cend()) {
|
||||
fp32_inputs.push_back({node, node_arg_slot});
|
||||
} else {
|
||||
const auto index_it = std::find(it->second.cbegin(), it->second.cend(), node_arg_slot);
|
||||
if (index_it == it->second.cend()) {
|
||||
fp16_inputs.push_back({node, node_arg_slot});
|
||||
} else {
|
||||
auto it2 = fp32_node_args_by_node.find(node);
|
||||
if (it2 != fp32_node_args_by_node.cend() &&
|
||||
std::find(it2->second.cbegin(), it2->second.cend(), node_arg_slot) != it2->second.cend()) {
|
||||
fp32_inputs.push_back({node, node_arg_slot});
|
||||
} else {
|
||||
fp16_inputs.push_back({node, node_arg_slot});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -120,9 +138,11 @@ static void RewireCastedNodeArg(onnxruntime::Graph& graph,
|
|||
}
|
||||
|
||||
// This function tries casting `arg` to `element_type`.
|
||||
// The argument `fp32_node_args` specifies the cases where the `arg` should be 32-bit float.
|
||||
// The argument `fp32_node_args_by_op_type` specifies the cases where the `arg` should be 32-bit float using op type.
|
||||
// The argument `fp32_node_args_by_node` specifies the cases where the `arg` should be 32-bit float using node pointer.
|
||||
static Status CastNodeArg(onnxruntime::Graph& graph,
|
||||
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args,
|
||||
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args_by_op_type,
|
||||
const std::unordered_map<Node*, std::vector<int>>& fp32_node_args_by_node,
|
||||
NodeArg* arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType elem_type) {
|
||||
if (arg == nullptr) {
|
||||
|
|
@ -135,7 +155,7 @@ static Status CastNodeArg(onnxruntime::Graph& graph,
|
|||
// Get consumer nodes of the input `arg`
|
||||
std::vector<std::pair<Node*, int>> fp16_inputs;
|
||||
std::vector<std::pair<Node*, int>> fp32_inputs;
|
||||
GetConsumerNodeInputs(graph, fp32_node_args, arg, fp16_inputs, fp32_inputs);
|
||||
GetConsumerNodeInputs(graph, fp32_node_args_by_op_type, fp32_node_args_by_node, arg, fp16_inputs, fp32_inputs);
|
||||
if ((elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && fp16_inputs.empty()) ||
|
||||
(elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT && fp32_inputs.empty())) {
|
||||
return Status::OK();
|
||||
|
|
@ -216,24 +236,207 @@ static Status CastNodeArg(onnxruntime::Graph& graph,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TransformConstants(Graph& graph) {
|
||||
struct LossSubgraph {
|
||||
// All nodes belong to this subgraph.
|
||||
std::unordered_set<Node*> nodes_;
|
||||
|
||||
// NodeArgs that are inputs of this subgraph from outside, which need to be converted to FP32.
|
||||
std::unordered_set<NodeArg*> to_fp32_inputs_;
|
||||
|
||||
// NodeArgs that are outputs of this subgraph to outside, which need to be converted to FP16.
|
||||
std::unordered_set<NodeArg*> to_fp16_outputs_;
|
||||
|
||||
// Nodes that take float input from outside of subgraph, the input indices are also saved.
|
||||
// It's useful when calling CastNodeArg, so FP32 inputs will no need to be converted.
|
||||
std::unordered_map<Node*, std::vector<int>> fp32_node_args_;
|
||||
|
||||
LossSubgraph(Graph& graph) {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
// Get the nodes related to loss scale. It's a Mul node and it's grad nodes.
|
||||
// We initialize loss subgraph only when there is loss scale as input.
|
||||
std::vector<Node*> loss_scale_consumers = graph.GetMutableConsumerNodes(loss_scale_input);
|
||||
if (loss_scale_consumers.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
nodes_.insert(loss_scale_consumers.begin(), loss_scale_consumers.end());
|
||||
for (Node* node : loss_scale_consumers) {
|
||||
for (const NodeArg* output : node->OutputDefs()) {
|
||||
std::vector<Node*> level2_consumers = graph.GetMutableConsumerNodes(output->Name());
|
||||
nodes_.insert(level2_consumers.begin(), level2_consumers.end());
|
||||
}
|
||||
}
|
||||
|
||||
// The node number here depends on how to implement the gradient of Mul.
|
||||
// Add this check here for safety at certain level.
|
||||
ORT_ENFORCE(nodes_.size() == 3,
|
||||
"The node number of the loss scale and it's grad subgraph is expected to be 3.");
|
||||
|
||||
// Check if graph contains any loss Op from the white-list.
|
||||
// If not, then above loss scale related nodes are all we need.
|
||||
bool has_loss_subgraph_entry_node = false;
|
||||
for (auto index : order) {
|
||||
if (IsLossSubgraphEntryNode(graph.GetNode(index))) {
|
||||
has_loss_subgraph_entry_node = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If it contains one or more loss Ops from white-list, travel the graph again to get the whole loss subgraph.
|
||||
if (has_loss_subgraph_entry_node) {
|
||||
for (auto index : order) {
|
||||
Node* node = graph.GetNode(index);
|
||||
if (IsLossSubgraphEntryNode(node) || IsLossSubgraphExitNode(node)) {
|
||||
nodes_.insert(node);
|
||||
} else {
|
||||
// For other nodes, if it consumes any output of any node from loss subgraph, it also belongs to loss subgraph.
|
||||
bool part_of_loss_subgraph = false;
|
||||
for (NodeArg* input : node->MutableInputDefs()) {
|
||||
Node* producer_node = graph.GetMutableProducerNode(input->Name());
|
||||
if (producer_node != nullptr &&
|
||||
!IsLossSubgraphExitNode(producer_node) &&
|
||||
nodes_.find(producer_node) != nodes_.cend()) {
|
||||
part_of_loss_subgraph = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (part_of_loss_subgraph) {
|
||||
nodes_.insert(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We now have all the nodes of the loss subgraph. Now get all float inputs from outside.
|
||||
for (Node* node : nodes_) {
|
||||
int index = 0;
|
||||
for (NodeArg* input : node->MutableInputDefs()) {
|
||||
if (input->Name() != loss_scale_input && // loss_scale input will keep FP32, no need to handle here.
|
||||
input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
// If its producer is from outside, it's one of the inputs of this subgraph.
|
||||
Node* producer_node = graph.GetMutableProducerNode(input->Name());
|
||||
if (producer_node == nullptr || nodes_.find(producer_node) == nodes_.cend()) {
|
||||
to_fp32_inputs_.insert(input);
|
||||
if (fp32_node_args_.find(node) == fp32_node_args_.cend()) {
|
||||
fp32_node_args_[node] = {index};
|
||||
} else {
|
||||
fp32_node_args_[node].push_back(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
index++;
|
||||
}
|
||||
|
||||
// Get all float outputs to outside of the subgraph. They will be converted to FP16.
|
||||
for (NodeArg* output : node->MutableOutputDefs()) {
|
||||
if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
!ContainsAllConsumers(graph, output->Name())) {
|
||||
to_fp16_outputs_.insert(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Contains(Node* node) {
|
||||
return nodes_.find(node) != nodes_.cend();
|
||||
}
|
||||
|
||||
// Check if this loss subgraph contains all the consumers of given Arg.
|
||||
bool ContainsAllConsumers(Graph& graph, const std::string arg_name) {
|
||||
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(arg_name);
|
||||
for (Node* node : consumer_nodes) {
|
||||
if (nodes_.find(node) == nodes_.cend()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// For those inputs and constants that are already handled, remove them from the to_fp32 list.
|
||||
void RemoveFromToFP32Inputs(const std::string& arg_name) {
|
||||
auto it = to_fp32_inputs_.begin();
|
||||
while (it != to_fp32_inputs_.end()) {
|
||||
if ((*it)->Name() == arg_name) {
|
||||
it = to_fp32_inputs_.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<Node*, std::vector<int>>& GetFP32NodeArgs() {
|
||||
return fp32_node_args_;
|
||||
}
|
||||
|
||||
// Once all inputs, constants, and function calls are handled, it's time to convert all
|
||||
// inputs to FP32, and convert all outputs to FP16.
|
||||
Status ProcessInputsAndOutputs(Graph& graph) {
|
||||
for (auto* node_arg : to_fp32_inputs_) {
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
|
||||
stage1_fp32_node_args,
|
||||
fp32_node_args_,
|
||||
node_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||
}
|
||||
|
||||
for (auto* node_arg : to_fp16_outputs_) {
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
|
||||
stage1_fp32_node_args,
|
||||
fp32_node_args_,
|
||||
node_arg,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
Status TransformConstants(Graph& graph, LossSubgraph* p_loss_subgraph = nullptr) {
|
||||
// This pass does not require topological sort order: okay to visit nodes in any order.
|
||||
// We identify nodeargs to be converted to FP16 first, and then convert them separately
|
||||
// to avoid modifying the graph while iterating through it.
|
||||
std::unordered_set<NodeArg*> toFP16;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
// Ignore any node in loss subgraph.
|
||||
if (p_loss_subgraph != nullptr && p_loss_subgraph->Contains(&node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string& optype = node.OpType();
|
||||
// TODO: Why do we need to handle "Cast" here?
|
||||
if ((optype == "Constant") || (optype == "Cast") || (optype == "ConstantOfShape")) {
|
||||
for (NodeArg* output : node.MutableOutputDefs()) {
|
||||
if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
|
||||
toFP16.insert(output);
|
||||
if (output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
// If all consumers are from loss subgraph, don't convert it.
|
||||
if (p_loss_subgraph == nullptr || !p_loss_subgraph->ContainsAllConsumers(graph, output->Name())) {
|
||||
toFP16.insert(output);
|
||||
}
|
||||
|
||||
if (p_loss_subgraph != nullptr) {
|
||||
// If it's one of loss subgraph's input, remove it from the to-convert set since it's already handled.
|
||||
p_loss_subgraph->RemoveFromToFP32Inputs(output->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* tensor : toFP16) {
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage1_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
CastNodeArg(graph,
|
||||
stage1_fp32_node_args,
|
||||
p_loss_subgraph != nullptr ?
|
||||
p_loss_subgraph->GetFP32NodeArgs() :
|
||||
std::unordered_map<Node*, std::vector<int>>(),
|
||||
tensor,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -241,7 +444,8 @@ Status TransformConstants(Graph& graph) {
|
|||
// as SparseSoftmaxCrossEntropy where FP32 precision is required.
|
||||
// Converts fp16 tensor --> Op --> fp16 tensor to
|
||||
// fp16 tensor --> Cast --> fp32 tensor --> Op --> fp32 tensor --> Cast --> fp16 tensor
|
||||
Status TransformStage2(Graph& graph) {
|
||||
Status TransformStage2(Graph& graph,
|
||||
const std::unordered_map<Node*, std::vector<int>>& loss_subgraph_fp32_node_args = {}) {
|
||||
// This pass does not require topological sort order: okay to visit nodes in any order.
|
||||
std::unordered_set<NodeArg *> toFP16, toFP32;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
|
|
@ -260,13 +464,21 @@ Status TransformStage2(Graph& graph) {
|
|||
}
|
||||
}
|
||||
for (auto* tensor : toFP32)
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage2_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
|
||||
stage2_fp32_node_args,
|
||||
loss_subgraph_fp32_node_args,
|
||||
tensor,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||
for (auto* tensor : toFP16)
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage2_fp32_node_args, tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
|
||||
stage2_fp32_node_args,
|
||||
loss_subgraph_fp32_node_args,
|
||||
tensor,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status HandleFunctionCalls(Graph& graph);
|
||||
static Status HandleFunctionCalls(Graph& graph, LossSubgraph* p_loss_subgraph = nullptr);
|
||||
|
||||
// TODO: Ideally, we should not need to transform a function-body here.
|
||||
// Ideally, for any full-precision function F, there should be a corresponding 16-bit precision
|
||||
|
|
@ -304,7 +516,11 @@ static Status HandleFunctionBody(const Function& node_func) {
|
|||
// Introduce cast to full-precision if required:
|
||||
// TODO: fix const_cast; Graph doesn't provide us a method "GetMutableInputs".
|
||||
NodeArg* mutable_input = const_cast<NodeArg*>(input);
|
||||
CastNodeArg(graph, stage1_fp32_node_args, mutable_input, ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
CastNodeArg(graph,
|
||||
stage1_fp32_node_args,
|
||||
std::unordered_map<Node*, std::vector<int>>(),
|
||||
mutable_input,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -329,24 +545,30 @@ static Status HandleFunctionBody(const Function& node_func) {
|
|||
return status;
|
||||
}
|
||||
|
||||
static Status HandleFunctionCalls(Graph& graph) {
|
||||
static Status HandleFunctionCalls(Graph& graph, LossSubgraph* p_loss_subgraph) {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (auto index : order) {
|
||||
Node* node = graph.GetNode(index);
|
||||
if (!IsFP32Node(node)) { // Bodies of FP32 Functions are not transformed
|
||||
const Function* node_func = node->GetFunctionBody();
|
||||
if (nullptr != node_func) {
|
||||
ORT_RETURN_IF_ERROR(HandleFunctionBody(*node_func));
|
||||
}
|
||||
// Bodies of FP32 Functions are not transformed.
|
||||
if (IsFP32Node(node) ||
|
||||
(p_loss_subgraph != nullptr && p_loss_subgraph->Contains(node))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Function* node_func = node->GetFunctionBody();
|
||||
if (nullptr != node_func) {
|
||||
ORT_RETURN_IF_ERROR(HandleFunctionBody(*node_func));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create FP16 NodeArg and update the consumers of arg with new FP16 NodeArg.
|
||||
static NodeArg* CreateFP16NodeArgAndUpdateConsumers(Graph& graph,
|
||||
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args,
|
||||
const std::unordered_map<std::string, std::vector<int>>& fp32_node_args_by_op_type,
|
||||
const std::unordered_map<Node*, std::vector<int>>& fp32_node_args_by_node,
|
||||
const NodeArg* arg) {
|
||||
ORT_ENFORCE(arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
|
||||
"data type is not float");
|
||||
|
|
@ -360,7 +582,7 @@ static NodeArg* CreateFP16NodeArgAndUpdateConsumers(Graph& graph,
|
|||
// Check consumer nodes
|
||||
std::vector<std::pair<Node*, int>> fp16_inputs;
|
||||
std::vector<std::pair<Node*, int>> fp32_inputs;
|
||||
GetConsumerNodeInputs(graph, fp32_node_args, arg, fp16_inputs, fp32_inputs);
|
||||
GetConsumerNodeInputs(graph, fp32_node_args_by_op_type, fp32_node_args_by_node, arg, fp16_inputs, fp32_inputs);
|
||||
if (fp16_inputs.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -376,6 +598,9 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
const std::unordered_set<std::string>& weights_to_train,
|
||||
bool use_fp16_initializer,
|
||||
std::unordered_map<std::string, NodeArg*>& fp32_weight_name_to_fp16_node_arg) {
|
||||
// Stag 0: Initialize loss subgraph.
|
||||
LossSubgraph loss_subgraph(graph);
|
||||
|
||||
// Stage 1: Convert whole graph including forward and backward to FP16
|
||||
// Initialize function body for all function nodes
|
||||
// This is required to make sure after converting inputs\weights to FP16
|
||||
|
|
@ -385,10 +610,23 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
}
|
||||
|
||||
// Insert Cast node to convert inputs from FP32 to FP16
|
||||
// If all consumers are from loss graph, don't convert it, and remove it from To-32 loss graph inputs.
|
||||
for (const NodeArg* input : graph.GetInputs()) {
|
||||
if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
CastNodeArg(graph, stage1_fp32_node_args, graph.GetNodeArg(input->Name()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
// Input loss_scale will always keep as FP32.
|
||||
if (input->Name() != loss_scale_input &&
|
||||
input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
// If all consumers are from loss subgraph, no need to convert.
|
||||
if (!loss_subgraph.ContainsAllConsumers(graph, input->Name())) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
CastNodeArg(graph,
|
||||
stage1_fp32_node_args,
|
||||
loss_subgraph.GetFP32NodeArgs(),
|
||||
graph.GetNodeArg(input->Name()),
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
}
|
||||
|
||||
// Remove it from the to-convert set since it's already handled.
|
||||
loss_subgraph.RemoveFromToFP32Inputs(input->Name());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -399,18 +637,31 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
for (const auto& kv : initialized_tensors) {
|
||||
NodeArg* input = graph.GetNodeArg(kv.first);
|
||||
if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
if (use_fp16_initializer) {
|
||||
NodeArg* fp16_weight_arg = CreateFP16NodeArgAndUpdateConsumers(graph, stage1_fp32_node_args, input);
|
||||
if (fp16_weight_arg != nullptr) {
|
||||
fp16_initializers.emplace_back(fp16_weight_arg->Name(), kv.second);
|
||||
const auto it = weights_to_train.find(kv.first);
|
||||
if (it != weights_to_train.cend()) {
|
||||
fp32_weight_name_to_fp16_node_arg_result[kv.first] = fp16_weight_arg;
|
||||
// If all consumers are from loss graph, don't convert it.
|
||||
if (!loss_subgraph.ContainsAllConsumers(graph, input->Name())) {
|
||||
if (use_fp16_initializer) {
|
||||
NodeArg* fp16_weight_arg = CreateFP16NodeArgAndUpdateConsumers(graph,
|
||||
stage1_fp32_node_args,
|
||||
loss_subgraph.GetFP32NodeArgs(),
|
||||
input);
|
||||
if (fp16_weight_arg != nullptr) {
|
||||
fp16_initializers.emplace_back(fp16_weight_arg->Name(), kv.second);
|
||||
const auto it = weights_to_train.find(kv.first);
|
||||
if (it != weights_to_train.cend()) {
|
||||
fp32_weight_name_to_fp16_node_arg_result[kv.first] = fp16_weight_arg;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph,
|
||||
stage1_fp32_node_args,
|
||||
loss_subgraph.GetFP32NodeArgs(),
|
||||
input,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
}
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(CastNodeArg(graph, stage1_fp32_node_args, input, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
|
||||
}
|
||||
|
||||
// Remove it from the to-convert set since it's already handled.
|
||||
loss_subgraph.RemoveFromToFP32Inputs(input->Name());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -426,7 +677,7 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
for (auto& node : graph.Nodes()) {
|
||||
// For send and recv node, if the tensor being sent or received is FP32, update its
|
||||
// attribute and change it to FP16.
|
||||
if (!node.OpType().compare("Send") || !node.OpType().compare("Recv")) {
|
||||
if ((!node.OpType().compare("Send") || !node.OpType().compare("Recv")) && !loss_subgraph.Contains(&node)) {
|
||||
auto& attributes = node.GetMutableAttributes();
|
||||
auto* element_type = &(attributes.find("element_types")->second);
|
||||
int ints_size = element_type->ints_size();
|
||||
|
|
@ -441,10 +692,13 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
}
|
||||
|
||||
// Handle implicit data type casting nodes such as Cast, ConstantOfShape
|
||||
ORT_RETURN_IF_ERROR(TransformConstants(graph));
|
||||
ORT_RETURN_IF_ERROR(TransformConstants(graph, &loss_subgraph));
|
||||
|
||||
// Handle function body
|
||||
ORT_RETURN_IF_ERROR(HandleFunctionCalls(graph));
|
||||
ORT_RETURN_IF_ERROR(HandleFunctionCalls(graph, &loss_subgraph));
|
||||
|
||||
// Handle loss graph inputs and outputs.
|
||||
ORT_RETURN_IF_ERROR(loss_subgraph.ProcessInputsAndOutputs(graph));
|
||||
|
||||
// At this point, the model has been transformed to a valid FP16 model.
|
||||
|
||||
|
|
@ -454,7 +708,7 @@ Status TransformGraphForMixedPrecision(Graph& graph,
|
|||
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve(options));
|
||||
|
||||
TransformStage2(graph);
|
||||
TransformStage2(graph, loss_subgraph.GetFP32NodeArgs());
|
||||
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve(options));
|
||||
|
||||
|
|
|
|||
|
|
@ -794,7 +794,15 @@ const DataTransferManager& TrainingSession::GetDataTransferManager() const {
|
|||
bool TrainingSession::IsGraphOutputFp32Node(const std::string& output_name) const {
|
||||
auto output_producer_node = model_->MainGraph().GetProducerNode(output_name);
|
||||
ORT_ENFORCE(output_producer_node != nullptr, "Output: " + output_name + " is not produced by any node.");
|
||||
return IsFP32Node(output_producer_node);
|
||||
|
||||
for (auto output : output_producer_node->OutputDefs()) {
|
||||
if (output->Name() == output_name && output->TypeAsProto() != nullptr && output->TypeAsProto()->has_tensor_type()
|
||||
&& output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
common::Status TrainingSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue