mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Add level two optimizations for constant propagation transformation. (#7410)
* Made the python script generating the testcases modular. * Modified RemoveBackToBackCasts function to remove cast even if the parent node has other consumers. * Modified InsertCastNodes to update the graph consistently for other functions to work. * Moved ConcatNames function to the top. * PropagateBackward/SearchUpstream and PropagateFP16CastsFromOutputsToInputs insert FP32 casts if the level >1 in order to propagate FP16 casts backwards. * Added new testcases for level two setting.
This commit is contained in:
parent
f1c3f3fcc1
commit
979d63159b
77 changed files with 296 additions and 151 deletions
|
|
@ -11,10 +11,21 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
// NodeArg to Select consumer node map.
|
||||
typedef std::unordered_map<NodeArg*, std::vector<Node*>> NodeArgToConsumerMap;
|
||||
static std::string GetName(const std::pair<const NodeArg*, std::vector<Node*>>& p) {
|
||||
return p.first->Name();
|
||||
};
|
||||
|
||||
// ConcatNames
|
||||
// Collects all the names from the pointers of the objects stores in the container class C
|
||||
// the class should have a member functions returning a string (or a ref).
|
||||
template <typename C, typename T = typename C::value_type>
|
||||
static std::string ConcatNames(
|
||||
C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) {
|
||||
std::vector<std::string> names;
|
||||
std::transform(items.begin(), items.end(), back_inserter(names), f);
|
||||
return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; });
|
||||
}
|
||||
|
||||
static std::string GetName(const std::pair<const NodeArg*, std::vector<Node*>>& p) {
|
||||
return p.first->Name() + " feeding " + ConcatNames(p.second) + "; ";
|
||||
};
|
||||
// The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that
|
||||
// the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0.
|
||||
// The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code
|
||||
|
|
@ -65,7 +76,7 @@ static Status InsertCastNodes(Graph& graph,
|
|||
// data_type is the data type of the Cast output.
|
||||
TensorProto_DataType data_type = is_fp16 ? TensorProto_DataType_FLOAT16 : TensorProto_DataType_FLOAT;
|
||||
TypeProto type_proto;
|
||||
bool is_node_arg_cast_output = IsType(*node_arg, data_type);
|
||||
bool is_node_arg_cast_output = IsType(*node_arg, data_type); // true if the producer node_arg is being replaced
|
||||
TensorProto_DataType new_node_arg_data_type = data_type;
|
||||
|
||||
if (is_node_arg_cast_output) {
|
||||
|
|
@ -94,31 +105,59 @@ static Status InsertCastNodes(Graph& graph,
|
|||
inserted_node_names.insert(cast.Name());
|
||||
Node* producer = graph.GetMutableProducerNode(node_arg->Name());
|
||||
std::vector<Node*> consumers = graph.GetMutableConsumerNodes(node_arg->Name());
|
||||
int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *node_arg) : -1;
|
||||
// Update consumers of node_arg to use the output of the cast node
|
||||
int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output);
|
||||
for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) {
|
||||
if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() &&
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
|
||||
auto& consumer_inputs = consumer->MutableInputDefs();
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
|
||||
if (nullptr != producer) {
|
||||
std::vector<Node*> other_nodes;
|
||||
if (nullptr != producer) {
|
||||
int output_index = optimizer_utils::IndexOfNodeOutput(*producer, *node_arg);
|
||||
for (Node* consumer : consumers) {
|
||||
// Removed the edges to the consumers, getting casted.
|
||||
if (std::find(nodes.begin(), nodes.end(), consumer) != nodes.end()) {
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
|
||||
graph.RemoveEdge(producer->Index(), consumer->Index(), output_index, input_index);
|
||||
} else {
|
||||
other_nodes.push_back(consumer);
|
||||
}
|
||||
std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output);
|
||||
graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index);
|
||||
}
|
||||
if (is_node_arg_cast_output) {
|
||||
// Replace the node_arg with the new_node_arg in the producer outputs
|
||||
auto& producer_outputs = producer->MutableOutputDefs();
|
||||
std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input);
|
||||
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
|
||||
}
|
||||
}
|
||||
if (nullptr != producer) {
|
||||
auto& producer_outputs = producer->MutableOutputDefs();
|
||||
// The following replacement is necessary in case where the output of the cast node is original
|
||||
// output of the producer, for example the original output of the producer may be the graph output.
|
||||
std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input);
|
||||
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input);
|
||||
graph.AddEdge(producer->Index(), cast.Index(), output_index, input_index);
|
||||
// Update consumers of node_arg to use the output of the cast node
|
||||
int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output);
|
||||
for (Node* consumer : consumers) {
|
||||
if (std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
|
||||
if (std::find(nodes.begin(), nodes.end(), consumer) == nodes.end()) {
|
||||
// Consumers not getting casted need to replace input-def if the producer's output-def is changed
|
||||
if (is_node_arg_cast_output) {
|
||||
auto& consumer_inputs = consumer->MutableInputDefs();
|
||||
std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_output, &cast_input);
|
||||
}
|
||||
} else {
|
||||
// Consumers getting casted need to get new edges from the new cast node..
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
|
||||
if (!is_node_arg_cast_output) {
|
||||
auto& consumer_inputs = consumer->MutableInputDefs();
|
||||
std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output);
|
||||
}
|
||||
graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Complete the input/output connections to the new cast node, and update the graph.
|
||||
other_nodes.push_back(&cast);
|
||||
graph.UpdateProducerNode(cast_output.Name(), cast.Index());
|
||||
graph.UpdateConsumerNodes(cast_output.Name(), nodes);
|
||||
graph.UpdateConsumerNodes(cast_input.Name(), other_nodes);
|
||||
if (nullptr != producer) {
|
||||
int cast_input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input);
|
||||
int output_index = optimizer_utils::IndexOfNodeOutput(*producer, cast_input);
|
||||
graph.AddEdge(producer->Index(), cast.Index(), output_index, cast_input_index);
|
||||
if (is_node_arg_cast_output) {
|
||||
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -138,7 +177,7 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
|
|||
if (producer) {
|
||||
if (graph.IsOutput(cast_output)) {
|
||||
// cast_output is a graph output. Replace the cast node with an Identity operator unless node
|
||||
// has other outputs.
|
||||
// has no other outputs.
|
||||
if (producer->GetOutputEdgesCount() == 1) {
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input);
|
||||
graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index);
|
||||
|
|
@ -146,11 +185,11 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
|
|||
std::replace(outputs.begin(), outputs.end(), cast_input, cast_output);
|
||||
graph.UpdateProducerNode(cast_output->Name(), producer->Index());
|
||||
} else {
|
||||
(void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"),
|
||||
"Identity",
|
||||
"Created as a place-holder for a graph output",
|
||||
{cast_input},
|
||||
{cast_output});
|
||||
(void)graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"),
|
||||
"Identity",
|
||||
"Created as a place-holder for a graph output",
|
||||
{cast_input},
|
||||
{cast_output});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -183,19 +222,21 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
|
|||
// and the second cast is from FLOAT to FLOAT16.
|
||||
// Condition: The parent cast should have only one output
|
||||
// The inputs is Cast to FLOAT16
|
||||
static bool RemoveBackToBackCasts(Graph& graph, Node* node,
|
||||
static bool RemoveBackToBackCasts(Graph& graph, Node* parent,
|
||||
std::deque<NodeIndex>& removed_nodes,
|
||||
const logging::Logger& logger) {
|
||||
ORT_ENFORCE(IsCastTo(node, TensorProto::FLOAT));
|
||||
ORT_ENFORCE(IsCastTo(parent, TensorProto::FLOAT));
|
||||
bool modified = false;
|
||||
if (graph_utils::CanRemoveNode(graph, *node, logger)) {
|
||||
NodeArg* cast_output = node->MutableOutputDefs()[0];
|
||||
for (Node* child : graph.GetMutableConsumerNodes(cast_output->Name())) {
|
||||
if (graph_utils::CanRemoveNode(graph, *parent, logger)) {
|
||||
NodeArg* cast_output = parent->MutableOutputDefs()[0];
|
||||
std::vector<Node*> children = graph.GetMutableConsumerNodes(cast_output->Name());
|
||||
if (children.size() == 1) {
|
||||
Node* child = children[0];
|
||||
if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) {
|
||||
if (IsCastTo(child, TensorProto::FLOAT16)) {
|
||||
// The parent and child cancell out
|
||||
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << node->Name() << " and " << child->Name();
|
||||
RemoveCastNodesChain(graph, {node, child}, removed_nodes);
|
||||
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << parent->Name() << " and " << child->Name();
|
||||
RemoveCastNodesChain(graph, {parent, child}, removed_nodes);
|
||||
modified = true;
|
||||
} else if (IsCastTo(child, TensorProto::FLOAT)) {
|
||||
// Child is a duplicate of parent
|
||||
|
|
@ -204,6 +245,45 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node,
|
|||
modified = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
NodeArg* parent_input = parent->MutableInputDefs()[0];
|
||||
const Node* producer = graph.GetProducerNode(parent_input->Name());
|
||||
int producer_output_index = producer ? optimizer_utils::IndexOfNodeOutput(*producer, *parent_input) : -1;
|
||||
std::vector<Node*> new_consumers;
|
||||
for (Node* child : children) {
|
||||
if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) {
|
||||
if (IsCastTo(child, TensorProto::FLOAT16)) {
|
||||
// The parent and child cancell out
|
||||
// Remove the child node without effecting the other nodes.
|
||||
// move all the consumers to the producer.
|
||||
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name();
|
||||
NodeArg* child_output = child->MutableOutputDefs()[0];
|
||||
for (Node* consumer : graph.GetMutableConsumerNodes(child_output->Name())) {
|
||||
std::vector<NodeArg*>& consumer_inputs = consumer->MutableInputDefs();
|
||||
int output_index = optimizer_utils::IndexOfNodeOutput(*child, *child_output);
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *child_output);
|
||||
graph.RemoveEdge(child->Index(), consumer->Index(), output_index, input_index);
|
||||
std::replace(consumer_inputs.begin(), consumer_inputs.end(), child_output, parent_input);
|
||||
if (nullptr != producer) {
|
||||
graph.AddEdge(producer->Index(), consumer->Index(), producer_output_index, input_index);
|
||||
}
|
||||
new_consumers.push_back(consumer);
|
||||
}
|
||||
modified = true;
|
||||
removed_nodes.push_front(child->Index());
|
||||
} else if (IsCastTo(child, TensorProto::FLOAT)) {
|
||||
// Child is a duplicate of parent
|
||||
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name();
|
||||
RemoveCastNodesChain(graph, {child}, removed_nodes);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (new_consumers.size() > 0) {
|
||||
std::vector<Node*> consumers = graph.GetMutableConsumerNodes(parent_input->Name());
|
||||
std::copy(new_consumers.begin(), new_consumers.end(), back_inserter(consumers));
|
||||
graph.UpdateConsumerNodes(parent_input->Name(), consumers);
|
||||
}
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
|
|
@ -216,11 +296,18 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node,
|
|||
// nodearg is traversed not more than once.
|
||||
static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node,
|
||||
NodeArgToConsumerMap& require_cast,
|
||||
NodeArgToConsumerMap& require_cast_fp32,
|
||||
std::unordered_set<NodeArg*>& require_type_change,
|
||||
std::deque<NodeIndex>& removed_nodes,
|
||||
size_t level) {
|
||||
Node* node = graph.GetMutableProducerNode(node_arg->Name());
|
||||
if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
|
||||
// If the Cast input feeds more than one node or the cast node feeds a graph output and at least one
|
||||
// node then it cannot propagate.
|
||||
size_t consumer_node_count = graph.GetConsumerNodes(node_arg->Name()).size();
|
||||
if (level < 2 && (consumer_node_count > 1 ||
|
||||
nullptr != node &&
|
||||
consumer_node_count > 0 &&
|
||||
graph.IsOutput(node_arg))) {
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
} else if (node == nullptr) {
|
||||
// The graph inputs don't have the producer nodes
|
||||
|
|
@ -244,18 +331,30 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node,
|
|||
// TODO: If the specified optimization is greater than 1 then insert a Cast to the
|
||||
// other output_def and still propagate FP16 cast up the graph.
|
||||
if (output_def != node_arg) {
|
||||
if (IsType(*output_def, TensorProto_DataType_FLOAT)) {
|
||||
if (IsType(*output_def, TensorProto_DataType_FLOAT) && graph.GetConsumerNodes(output_def->Name()).size() > 0) {
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (level >= 2) {
|
||||
for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) {
|
||||
if (nullptr != consumer && consumer != dst_node && consumer->OpType() != "Cast" &&
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
|
||||
require_cast_fp32[node_arg].push_back(consumer);
|
||||
}
|
||||
}
|
||||
if (graph.IsOutput(node_arg)) {
|
||||
require_cast_fp32[node_arg] = std::vector<Node*>();
|
||||
}
|
||||
}
|
||||
for (NodeArg* node_input : node->MutableInputDefs()) {
|
||||
if (IsType(*node_input, TensorProto_DataType_FLOAT) &&
|
||||
require_cast.find(node_input) == require_cast.end() &&
|
||||
require_type_change.find(node_input) == require_type_change.end()) {
|
||||
SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level);
|
||||
if (require_cast.find(node_input) == require_cast.end()) {
|
||||
SearchUpstream(graph, node_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level);
|
||||
if (require_cast.find(node_input) == require_cast.end() &&
|
||||
require_cast_fp32.find(node_input) == require_cast_fp32.end()) {
|
||||
require_type_change.insert(node_input);
|
||||
}
|
||||
}
|
||||
|
|
@ -318,17 +417,6 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
|
|||
}
|
||||
}
|
||||
|
||||
// ConcatNames
|
||||
// Collects all the names from the pointers of the objects stores in the container class C
|
||||
// the class should have a member functions returning a string (or a ref).
|
||||
template <typename C, typename T = typename C::value_type>
|
||||
static std::string ConcatNames(
|
||||
C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) {
|
||||
std::vector<std::string> names;
|
||||
std::transform(items.begin(), items.end(), back_inserter(names), f);
|
||||
return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; });
|
||||
}
|
||||
|
||||
// Change the elem_type of the given NodeArgs from FLOAT to FLOAT16.
|
||||
static void ChangeTypeToFP16(Graph& graph, std::unordered_set<NodeArg*>& require_type_change, bool is_forward, const logging::Logger& logger) {
|
||||
ONNX_NAMESPACE::TypeProto type_proto;
|
||||
|
|
@ -401,25 +489,30 @@ static bool PropagateBackwards(Graph& graph, Node* node,
|
|||
bool modified = false;
|
||||
ORT_ENFORCE(nullptr != node);
|
||||
NodeArgToConsumerMap require_cast;
|
||||
NodeArgToConsumerMap require_cast_fp32;
|
||||
NodeArg* cast_input = node->MutableInputDefs()[0];
|
||||
const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs
|
||||
// If the Cast input feeds more than one node or the cast node feeds a graph output and at least one
|
||||
// node then it cannot propagate.
|
||||
size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size();
|
||||
if (consumer_node_count > 1 ||
|
||||
(nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) {
|
||||
return modified;
|
||||
std::unordered_set<NodeArg*> require_type_change;
|
||||
if (graph.IsOutput(cast_input)) {
|
||||
return false;
|
||||
}
|
||||
std::unordered_set<NodeArg*> require_type_change = {cast_input};
|
||||
SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level);
|
||||
if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) {
|
||||
SearchUpstream(graph, cast_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level);
|
||||
if (require_cast_fp32.empty()) {
|
||||
require_type_change.insert(cast_input);
|
||||
}
|
||||
// TODO need a huristic when to insert FP32 Cast
|
||||
if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end() /* && require_cast.size() >= require_cast_fp32.size() */) {
|
||||
// Remove Cast operation
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name();
|
||||
if (require_cast_fp32.size() > 0) {
|
||||
InsertCastNodes(graph, require_cast_fp32, false, removed_nodes);
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted FP32 Cast nodes "
|
||||
<< ConcatNames<NodeArgToConsumerMap>(require_cast_fp32, GetName);
|
||||
}
|
||||
RemoveCastNodesChain(graph, {node}, removed_nodes);
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name();
|
||||
InsertCastNodes(graph, require_cast, true, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, false, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes "
|
||||
<< ConcatNames<NodeArgToConsumerMap>(require_cast, GetName);
|
||||
ChangeTypeToFP16(graph, require_type_change, false, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : "
|
||||
<< ConcatNames<std::unordered_set<NodeArg*>>(require_type_change);
|
||||
modified = true;
|
||||
|
|
@ -585,6 +678,7 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
|
|||
std::vector<Node*> casts; // Cast nodes to propagate.
|
||||
std::vector<NodeArg*>& outputs = node->MutableOutputDefs();
|
||||
std::unordered_set<NodeArg*> require_type_change;
|
||||
NodeArgToConsumerMap non_cast_consumers_map;
|
||||
// TODO Here we require the all floating point outputs are consumer by an immediate
|
||||
// child cast node.
|
||||
for (auto iter = outputs.begin(); iter != outputs.end() && all_float_outputs_have_casts; ++iter) {
|
||||
|
|
@ -594,22 +688,28 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
|
|||
}
|
||||
has_float_outputs = true;
|
||||
std::vector<Node*> consumers = graph.GetMutableConsumerNodes(output->Name());
|
||||
for (auto node_iter = consumers.begin(); node_iter != consumers.end() && all_float_outputs_have_casts; ++node_iter) {
|
||||
for (auto node_iter = consumers.begin(); node_iter != consumers.end() && (level >= 2 || all_float_outputs_have_casts); ++node_iter) {
|
||||
Node* consumer = *node_iter;
|
||||
if (nullptr != consumer &&
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end() &&
|
||||
IsCastTo(consumer, TensorProto::FLOAT16)) {
|
||||
casts.push_back(consumer);
|
||||
continue;
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
|
||||
if (IsCastTo(consumer, TensorProto::FLOAT16)) {
|
||||
casts.push_back(consumer);
|
||||
continue;
|
||||
} else {
|
||||
non_cast_consumers_map[output].push_back(consumer);
|
||||
}
|
||||
}
|
||||
all_float_outputs_have_casts = false;
|
||||
}
|
||||
require_type_change.insert(output);
|
||||
if (non_cast_consumers_map.empty()) {
|
||||
require_type_change.insert(output);
|
||||
}
|
||||
}
|
||||
if (has_float_outputs && all_float_outputs_have_casts && casts.size() > 1) {
|
||||
if (has_float_outputs && (level >= 2 || all_float_outputs_have_casts) && casts.size() > 1) {
|
||||
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Removed Cast nodes "
|
||||
<< ConcatNames<std::vector<Node*>>(casts)
|
||||
<< " feeding the same compute node " << node->Name();
|
||||
<< " feeding from the same compute node " << node->Name();
|
||||
InsertCastNodes(graph, non_cast_consumers_map, false, removed_nodes);
|
||||
for (Node* cast : casts) {
|
||||
RemoveCastNodesChain(graph, {cast}, removed_nodes);
|
||||
}
|
||||
|
|
@ -623,6 +723,8 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
|
|||
ChangeTypeToFP16(graph, require_type_change, false, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to "
|
||||
<< ConcatNames<NodeArgToConsumerMap>(node_args_map, GetName);
|
||||
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted FP32 Cast node to "
|
||||
<< ConcatNames<NodeArgToConsumerMap>(non_cast_consumers_map, GetName);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -705,16 +807,6 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
|
|||
}
|
||||
}
|
||||
|
||||
// Propagate FP16 Casts backward
|
||||
for (auto node_index : node_topology_list) {
|
||||
Node* node = graph.GetNode(node_index);
|
||||
if (nullptr != node &&
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() &&
|
||||
IsCastTo(node, TensorProto::FLOAT16)) {
|
||||
local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate FP16 Casts from outputs to inputs
|
||||
for (auto node_index : node_topology_list) {
|
||||
Node* node = graph.GetNode(node_index);
|
||||
|
|
@ -724,6 +816,16 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
|
|||
}
|
||||
}
|
||||
|
||||
// Propagate FP16 Casts
|
||||
for (auto node_index : node_topology_list) {
|
||||
Node* node = graph.GetNode(node_index);
|
||||
if (nullptr != node &&
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() &&
|
||||
IsCastTo(node, TensorProto::FLOAT16)) {
|
||||
local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate FP32 Casts from inputs to outputs
|
||||
for (auto node_index : node_topology_list) {
|
||||
Node* node = graph.GetNode(node_index);
|
||||
|
|
@ -755,7 +857,8 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
|
|||
|
||||
LOGS(logger, INFO) << "Nodes Converted to FP16:";
|
||||
std::for_each(converted_node_names.begin(), converted_node_names.end(), [removed_node_names, logger](std::string name) {
|
||||
if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } });
|
||||
if (removed_node_names.find(name) == removed_node_names.end() && inserted_node_names.find(name) == inserted_node_names.end()) {
|
||||
LOGS(logger, INFO) << name; } });
|
||||
}
|
||||
inserted_node_names.clear();
|
||||
converted_node_names.clear();
|
||||
|
|
|
|||
|
|
@ -4182,7 +4182,30 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
|
|||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}};
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 2, allow_matmul},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 3, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 3, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 3, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}};
|
||||
|
||||
// Create a temporary directory, which will be deleted automatically, to save/load the transformed models.
|
||||
TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")};
|
||||
|
|
|
|||
|
|
@ -135,57 +135,71 @@ def do_cast_inputs(input_0, input_1, nodes):
|
|||
to = input_cast_type)])
|
||||
return "cast_"+input_0, "cast_"+input_1
|
||||
def do_transpose_inputs(input_0, input_1, nodes):
|
||||
nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"),
|
||||
helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")])
|
||||
return "transpose_"+input_0, "transpose_"+input_1
|
||||
nodes.extend([helper.make_node("Transpose", [input_0], ["input_transpose_0"], "Transpose_0"),
|
||||
helper.make_node("Transpose", [input_1], ["input_transpose_1"], "Transpose_1")])
|
||||
return "input_transpose_0", "input_transpose_1"
|
||||
|
||||
def do_cast_product(product, nodes):
|
||||
nodes.append(helper.make_node(
|
||||
nodes.insert(1,helper.make_node(
|
||||
"Cast",
|
||||
[product],
|
||||
["cast" + product],
|
||||
["product_cast"],
|
||||
"Cast_2",
|
||||
to = TensorProto.FLOAT16))
|
||||
return "cast_"+product
|
||||
return "product_cast"
|
||||
|
||||
def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2):
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
def do_transpose_product(product, nodes):
|
||||
if transpose_product:
|
||||
nodes.append(helper.make_node("Transpose", [product], ["product_transpose"], "Transpose_2"))
|
||||
return "product_transpose"
|
||||
|
||||
def do_cast_sum(sum, nodes, type):
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
[sum],
|
||||
["cast_sum"],
|
||||
"Cast_3",
|
||||
to = type))
|
||||
return "cast_sum"
|
||||
|
||||
def do_cast_input2(input_2, nodes, type):
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
[input_2],
|
||||
["cast_"+input_2],
|
||||
"Cast_4",
|
||||
to = type))
|
||||
return "cast_"+input_2
|
||||
|
||||
def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2, transpose_inputs_before_cast=False):
|
||||
input_0 = "input_0"
|
||||
input_1 = "input_1"
|
||||
product = "product"
|
||||
nodes = []
|
||||
if transpose_inputs_before_cast:
|
||||
if transpose_inputs:
|
||||
input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes)
|
||||
if cast_inputs:
|
||||
input_0, input_1 = do_cast_inputs(input_0, input_1, nodes)
|
||||
else:
|
||||
if cast_inputs:
|
||||
input_0, input_1 = do_cast_inputs(input_0, input_1, nodes)
|
||||
if transpose_inputs:
|
||||
input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes)
|
||||
nodes.append(helper.make_node(
|
||||
"MatMul",
|
||||
["input_transpose_0" if transpose_inputs else ("cast_input_0" if cast_inputs else "input_0"),
|
||||
"input_transpose_1" if transpose_inputs else ("cast_input_1" if cast_inputs else "input_1")],
|
||||
["product"],
|
||||
[input_0,
|
||||
input_1],
|
||||
[product],
|
||||
"MatMul_0")
|
||||
]
|
||||
)
|
||||
if transpose_product:
|
||||
product = do_transpose_product(product, nodes)
|
||||
|
||||
if cast_product:
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
["product_transpose" if transpose_product else "product"],
|
||||
["product_cast"],
|
||||
"Cast_2",
|
||||
to = TensorProto.FLOAT16))
|
||||
product = do_cast_product(product, nodes)
|
||||
|
||||
if cast_inputs:
|
||||
input_cast_type = TensorProto.FLOAT
|
||||
nodes.extend([helper.make_node(
|
||||
"Cast",
|
||||
["input_0"],
|
||||
["cast_input_0"],
|
||||
"Cast_0",
|
||||
to = TensorProto.FLOAT),
|
||||
helper.make_node(
|
||||
"Cast",
|
||||
["input_1"],
|
||||
["cast_input_1"],
|
||||
"Cast_1",
|
||||
to = TensorProto.FLOAT)])
|
||||
|
||||
if transpose_inputs:
|
||||
nodes.extend([helper.make_node("Transpose", ["cast_input_0" if cast_inputs else "input_0"], ["input_transpose_0"], "Transpose_0"),
|
||||
helper.make_node("Transpose", ["cast_input_1" if cast_inputs else "input_1"], ["input_transpose_1"], "Transpose_1")])
|
||||
|
||||
if transpose_product:
|
||||
nodes.append(helper.make_node("Transpose", ["product"], ["product_transpose"], "Transpose_2"))
|
||||
output = product
|
||||
|
||||
input_type = TensorProto.FLOAT16 if cast_inputs else TensorProto.FLOAT
|
||||
output_type = flip_type(cast_sum, flip_type(cast_product, flip_type(cast_inputs, input_type)))
|
||||
|
|
@ -196,28 +210,21 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc
|
|||
"input_1", input_type, ['N', 'N'])
|
||||
]
|
||||
if insert_add:
|
||||
|
||||
input_2 = "input_2"
|
||||
add_input_type = flip_type(True, input_type) if cast_inputs != cast_product else input_type
|
||||
add_input_type = flip_type(cast_input2, add_input_type)
|
||||
inputs.append(helper.make_tensor_value_info("input_2", add_input_type, ['N', 'N']))
|
||||
nodes.append(helper.make_node("Add", ["product_cast" if cast_product else ("product_transpose" if transpose_product else "product"), "cast_input_2" if cast_input2 else "input_2"], ["sum"], "Add_0"))
|
||||
if cast_sum:
|
||||
input2_cast_type = flip_type(True, flip_type(cast_input2, add_input_type))
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
["sum"],
|
||||
["cast_sum"],
|
||||
"Cast_3",
|
||||
to = input2_cast_type))
|
||||
inputs.append(helper.make_tensor_value_info(input_2, add_input_type, ['N', 'N']))
|
||||
add_output = "sum"
|
||||
if cast_input2:
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
["input_2"],
|
||||
["cast_input_2"],
|
||||
"Cast_4",
|
||||
to = flip_type(True, add_input_type)))
|
||||
input_2 = do_cast_input2(input_2, nodes, flip_type(True, add_input_type))
|
||||
nodes.append(helper.make_node("Add", [product, input_2], [add_output], "Add_0"))
|
||||
if cast_sum:
|
||||
add_output = do_cast_sum(add_output, nodes, flip_type(not cast_input2, add_input_type))
|
||||
output = add_output
|
||||
outputs = [
|
||||
helper.make_tensor_value_info(
|
||||
"cast_sum" if cast_sum else "sum" if insert_add else ("product_cast" if cast_product else ("product_transpose" if transpose_product else "product")), output_type, ['N', 'N'])
|
||||
output, output_type, ['N', 'N'])
|
||||
]
|
||||
|
||||
save(model_path + ("_transpose_inputs" if transpose_inputs else "") +
|
||||
|
|
@ -229,11 +236,12 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc
|
|||
nodes, inputs, outputs, [])
|
||||
|
||||
def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul):
|
||||
def do_transpose(output_0, output_1, nodes):
|
||||
nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"),
|
||||
helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")])
|
||||
def do_transpose(output_0, output_1, transpose, nodes):
|
||||
nodes.append(helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"))
|
||||
output_0 = "transpose_0_"+output_0
|
||||
output_1 ="transpose_1_"+output_1
|
||||
if transpose > 1:
|
||||
nodes.append(helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1"))
|
||||
output_1 ="transpose_1_"+output_1
|
||||
return output_0, output_1
|
||||
input_type = TensorProto.FLOAT
|
||||
input_0 = "input_0"
|
||||
|
|
@ -262,9 +270,16 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
|
|||
"MatMul_1"))
|
||||
outputs.append(helper.make_tensor_value_info(
|
||||
"second_"+output, input_type, ['M', 'N']))
|
||||
|
||||
if transpose and transpose_before_cast:
|
||||
output_0, output_1 = do_transpose(output_0, output_1, nodes)
|
||||
if add_products:
|
||||
nodes.append(helper.make_node(
|
||||
"Add",
|
||||
[output, "second_"+output],
|
||||
["sum"],
|
||||
"Add_0"))
|
||||
outputs.append(helper.make_tensor_value_info(
|
||||
"sum", input_type, ['M', 'N']))
|
||||
if transpose > 0 and transpose_before_cast:
|
||||
output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes)
|
||||
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
|
|
@ -283,8 +298,8 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
|
|||
to = TensorProto.FLOAT16))
|
||||
output_1 = "cast_1_"+output_1
|
||||
|
||||
if transpose and not transpose_before_cast:
|
||||
output_0, output_1 = do_transpose(output_0, output_1, nodes)
|
||||
if transpose > 0 and not transpose_before_cast:
|
||||
output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes)
|
||||
|
||||
outputs.extend([
|
||||
helper.make_tensor_value_info(
|
||||
|
|
@ -292,8 +307,10 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
|
|||
helper.make_tensor_value_info(
|
||||
output_1, flip_type(second_matmul, input_type), ['M', 'N'])
|
||||
])
|
||||
model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else ""
|
||||
model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose > 0 else ""
|
||||
model_path += "_transpose" if transpose > 1 else ""
|
||||
model_path += "_second_matmul" if second_matmul else ""
|
||||
model_path += "_add_products" if add_products else ""
|
||||
save(model_path, nodes, inputs, outputs, [])
|
||||
|
||||
for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)):
|
||||
|
|
@ -305,7 +322,9 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add,
|
|||
gen_fuse_sibling_casts("fuse_sibling_casts")
|
||||
gen_fuse_back2back_casts("fuse_back2back_casts")
|
||||
|
||||
for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)):
|
||||
for (transpose, transpose_before_cast, second_matmul, add_products) in list(itertools.product([0,1,2], [False, True], [False, True], [False, True])):
|
||||
if not transpose and transpose_before_cast:
|
||||
continue
|
||||
if not second_matmul and add_products:
|
||||
continue
|
||||
gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in a new issue