Add level two optimizations for constant propagation transformation. (#7410)

* Made the python script generating the testcases modular.

* Modified RemoveBackToBackCasts function to remove cast even if the parent node has other consumers.

* Modified InsertCastNodes to update the graph consistently for other functions to work.

* Moved ConcatNames function to the top.

* PropagateBackward/SearchUpstream and PropagateFP16CastsFromOutputsToInputs insert FP32 casts if the level >1 in order to propagate FP16 casts backwards.

* Added new testcases for level two setting.
This commit is contained in:
satyajandhyala 2021-04-23 13:25:54 -07:00 committed by GitHub
parent f1c3f3fcc1
commit 979d63159b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
77 changed files with 296 additions and 151 deletions

View file

@ -11,10 +11,21 @@ using namespace onnxruntime::common;
namespace onnxruntime {
// NodeArg to Select consumer node map.
typedef std::unordered_map<NodeArg*, std::vector<Node*>> NodeArgToConsumerMap;
static std::string GetName(const std::pair<const NodeArg*, std::vector<Node*>>& p) {
return p.first->Name();
};
// ConcatNames
// Collects all the names from the pointers of the objects stores in the container class C
// the class should have a member functions returning a string (or a ref).
template <typename C, typename T = typename C::value_type>
static std::string ConcatNames(
C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) {
std::vector<std::string> names;
std::transform(items.begin(), items.end(), back_inserter(names), f);
return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; });
}
static std::string GetName(const std::pair<const NodeArg*, std::vector<Node*>>& p) {
return p.first->Name() + " feeding " + ConcatNames(p.second) + "; ";
};
// The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that
// the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0.
// The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code
@ -65,7 +76,7 @@ static Status InsertCastNodes(Graph& graph,
// data_type is the data type of the Cast output.
TensorProto_DataType data_type = is_fp16 ? TensorProto_DataType_FLOAT16 : TensorProto_DataType_FLOAT;
TypeProto type_proto;
bool is_node_arg_cast_output = IsType(*node_arg, data_type);
bool is_node_arg_cast_output = IsType(*node_arg, data_type); // true if the producer node_arg is being replaced
TensorProto_DataType new_node_arg_data_type = data_type;
if (is_node_arg_cast_output) {
@ -94,31 +105,59 @@ static Status InsertCastNodes(Graph& graph,
inserted_node_names.insert(cast.Name());
Node* producer = graph.GetMutableProducerNode(node_arg->Name());
std::vector<Node*> consumers = graph.GetMutableConsumerNodes(node_arg->Name());
int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *node_arg) : -1;
// Update consumers of node_arg to use the output of the cast node
int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output);
for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) {
if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() &&
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
auto& consumer_inputs = consumer->MutableInputDefs();
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
if (nullptr != producer) {
std::vector<Node*> other_nodes;
if (nullptr != producer) {
int output_index = optimizer_utils::IndexOfNodeOutput(*producer, *node_arg);
for (Node* consumer : consumers) {
// Removed the edges to the consumers, getting casted.
if (std::find(nodes.begin(), nodes.end(), consumer) != nodes.end()) {
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
graph.RemoveEdge(producer->Index(), consumer->Index(), output_index, input_index);
} else {
other_nodes.push_back(consumer);
}
std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output);
graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index);
}
if (is_node_arg_cast_output) {
// Replace the node_arg with the new_node_arg in the producer outputs
auto& producer_outputs = producer->MutableOutputDefs();
std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input);
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
}
}
if (nullptr != producer) {
auto& producer_outputs = producer->MutableOutputDefs();
// The following replacement is necessary in case where the output of the cast node is original
// output of the producer, for example the original output of the producer may be the graph output.
std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input);
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input);
graph.AddEdge(producer->Index(), cast.Index(), output_index, input_index);
// Update consumers of node_arg to use the output of the cast node
int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output);
for (Node* consumer : consumers) {
if (std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
if (std::find(nodes.begin(), nodes.end(), consumer) == nodes.end()) {
// Consumers not getting casted need to replace input-def if the producer's output-def is changed
if (is_node_arg_cast_output) {
auto& consumer_inputs = consumer->MutableInputDefs();
std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_output, &cast_input);
}
} else {
// Consumers getting casted need to get new edges from the new cast node..
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
if (!is_node_arg_cast_output) {
auto& consumer_inputs = consumer->MutableInputDefs();
std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output);
}
graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index);
}
}
}
// Complete the input/output connections to the new cast node, and update the graph.
other_nodes.push_back(&cast);
graph.UpdateProducerNode(cast_output.Name(), cast.Index());
graph.UpdateConsumerNodes(cast_output.Name(), nodes);
graph.UpdateConsumerNodes(cast_input.Name(), other_nodes);
if (nullptr != producer) {
int cast_input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input);
int output_index = optimizer_utils::IndexOfNodeOutput(*producer, cast_input);
graph.AddEdge(producer->Index(), cast.Index(), output_index, cast_input_index);
if (is_node_arg_cast_output) {
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
}
}
}
return Status::OK();
}
@ -138,7 +177,7 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
if (producer) {
if (graph.IsOutput(cast_output)) {
// cast_output is a graph output. Replace the cast node with an Identity operator unless node
// has other outputs.
// has no other outputs.
if (producer->GetOutputEdgesCount() == 1) {
int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input);
graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index);
@ -146,11 +185,11 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
std::replace(outputs.begin(), outputs.end(), cast_input, cast_output);
graph.UpdateProducerNode(cast_output->Name(), producer->Index());
} else {
(void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"),
"Identity",
"Created as a place-holder for a graph output",
{cast_input},
{cast_output});
(void)graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"),
"Identity",
"Created as a place-holder for a graph output",
{cast_input},
{cast_output});
}
}
}
@ -183,19 +222,21 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
// and the second cast is from FLOAT to FLOAT16.
// Condition: The parent cast should have only one output
// The inputs is Cast to FLOAT16
static bool RemoveBackToBackCasts(Graph& graph, Node* node,
static bool RemoveBackToBackCasts(Graph& graph, Node* parent,
std::deque<NodeIndex>& removed_nodes,
const logging::Logger& logger) {
ORT_ENFORCE(IsCastTo(node, TensorProto::FLOAT));
ORT_ENFORCE(IsCastTo(parent, TensorProto::FLOAT));
bool modified = false;
if (graph_utils::CanRemoveNode(graph, *node, logger)) {
NodeArg* cast_output = node->MutableOutputDefs()[0];
for (Node* child : graph.GetMutableConsumerNodes(cast_output->Name())) {
if (graph_utils::CanRemoveNode(graph, *parent, logger)) {
NodeArg* cast_output = parent->MutableOutputDefs()[0];
std::vector<Node*> children = graph.GetMutableConsumerNodes(cast_output->Name());
if (children.size() == 1) {
Node* child = children[0];
if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) {
if (IsCastTo(child, TensorProto::FLOAT16)) {
// The parent and child cancell out
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << node->Name() << " and " << child->Name();
RemoveCastNodesChain(graph, {node, child}, removed_nodes);
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << parent->Name() << " and " << child->Name();
RemoveCastNodesChain(graph, {parent, child}, removed_nodes);
modified = true;
} else if (IsCastTo(child, TensorProto::FLOAT)) {
// Child is a duplicate of parent
@ -204,6 +245,45 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node,
modified = true;
}
}
} else {
NodeArg* parent_input = parent->MutableInputDefs()[0];
const Node* producer = graph.GetProducerNode(parent_input->Name());
int producer_output_index = producer ? optimizer_utils::IndexOfNodeOutput(*producer, *parent_input) : -1;
std::vector<Node*> new_consumers;
for (Node* child : children) {
if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) {
if (IsCastTo(child, TensorProto::FLOAT16)) {
// The parent and child cancell out
// Remove the child node without effecting the other nodes.
// move all the consumers to the producer.
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name();
NodeArg* child_output = child->MutableOutputDefs()[0];
for (Node* consumer : graph.GetMutableConsumerNodes(child_output->Name())) {
std::vector<NodeArg*>& consumer_inputs = consumer->MutableInputDefs();
int output_index = optimizer_utils::IndexOfNodeOutput(*child, *child_output);
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *child_output);
graph.RemoveEdge(child->Index(), consumer->Index(), output_index, input_index);
std::replace(consumer_inputs.begin(), consumer_inputs.end(), child_output, parent_input);
if (nullptr != producer) {
graph.AddEdge(producer->Index(), consumer->Index(), producer_output_index, input_index);
}
new_consumers.push_back(consumer);
}
modified = true;
removed_nodes.push_front(child->Index());
} else if (IsCastTo(child, TensorProto::FLOAT)) {
// Child is a duplicate of parent
LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name();
RemoveCastNodesChain(graph, {child}, removed_nodes);
modified = true;
}
}
}
if (new_consumers.size() > 0) {
std::vector<Node*> consumers = graph.GetMutableConsumerNodes(parent_input->Name());
std::copy(new_consumers.begin(), new_consumers.end(), back_inserter(consumers));
graph.UpdateConsumerNodes(parent_input->Name(), consumers);
}
}
}
return modified;
@ -216,11 +296,18 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node,
// nodearg is traversed not more than once.
static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node,
NodeArgToConsumerMap& require_cast,
NodeArgToConsumerMap& require_cast_fp32,
std::unordered_set<NodeArg*>& require_type_change,
std::deque<NodeIndex>& removed_nodes,
size_t level) {
Node* node = graph.GetMutableProducerNode(node_arg->Name());
if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
// If the Cast input feeds more than one node or the cast node feeds a graph output and at least one
// node then it cannot propagate.
size_t consumer_node_count = graph.GetConsumerNodes(node_arg->Name()).size();
if (level < 2 && (consumer_node_count > 1 ||
nullptr != node &&
consumer_node_count > 0 &&
graph.IsOutput(node_arg))) {
require_cast[node_arg].push_back(dst_node);
} else if (node == nullptr) {
// The graph inputs don't have the producer nodes
@ -244,18 +331,30 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node,
// TODO: If the specified optimization is greater than 1 then insert a Cast to the
// other output_def and still propagate FP16 cast up the graph.
if (output_def != node_arg) {
if (IsType(*output_def, TensorProto_DataType_FLOAT)) {
if (IsType(*output_def, TensorProto_DataType_FLOAT) && graph.GetConsumerNodes(output_def->Name()).size() > 0) {
require_cast[node_arg].push_back(dst_node);
return;
}
}
}
if (level >= 2) {
for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) {
if (nullptr != consumer && consumer != dst_node && consumer->OpType() != "Cast" &&
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
require_cast_fp32[node_arg].push_back(consumer);
}
}
if (graph.IsOutput(node_arg)) {
require_cast_fp32[node_arg] = std::vector<Node*>();
}
}
for (NodeArg* node_input : node->MutableInputDefs()) {
if (IsType(*node_input, TensorProto_DataType_FLOAT) &&
require_cast.find(node_input) == require_cast.end() &&
require_type_change.find(node_input) == require_type_change.end()) {
SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level);
if (require_cast.find(node_input) == require_cast.end()) {
SearchUpstream(graph, node_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level);
if (require_cast.find(node_input) == require_cast.end() &&
require_cast_fp32.find(node_input) == require_cast_fp32.end()) {
require_type_change.insert(node_input);
}
}
@ -318,17 +417,6 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
}
}
// ConcatNames
// Collects all the names from the pointers of the objects stores in the container class C
// the class should have a member functions returning a string (or a ref).
template <typename C, typename T = typename C::value_type>
static std::string ConcatNames(
C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) {
std::vector<std::string> names;
std::transform(items.begin(), items.end(), back_inserter(names), f);
return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; });
}
// Change the elem_type of the given NodeArgs from FLOAT to FLOAT16.
static void ChangeTypeToFP16(Graph& graph, std::unordered_set<NodeArg*>& require_type_change, bool is_forward, const logging::Logger& logger) {
ONNX_NAMESPACE::TypeProto type_proto;
@ -401,25 +489,30 @@ static bool PropagateBackwards(Graph& graph, Node* node,
bool modified = false;
ORT_ENFORCE(nullptr != node);
NodeArgToConsumerMap require_cast;
NodeArgToConsumerMap require_cast_fp32;
NodeArg* cast_input = node->MutableInputDefs()[0];
const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs
// If the Cast input feeds more than one node or the cast node feeds a graph output and at least one
// node then it cannot propagate.
size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size();
if (consumer_node_count > 1 ||
(nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) {
return modified;
std::unordered_set<NodeArg*> require_type_change;
if (graph.IsOutput(cast_input)) {
return false;
}
std::unordered_set<NodeArg*> require_type_change = {cast_input};
SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level);
if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) {
SearchUpstream(graph, cast_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level);
if (require_cast_fp32.empty()) {
require_type_change.insert(cast_input);
}
// TODO need a huristic when to insert FP32 Cast
if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end() /* && require_cast.size() >= require_cast_fp32.size() */) {
// Remove Cast operation
LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name();
if (require_cast_fp32.size() > 0) {
InsertCastNodes(graph, require_cast_fp32, false, removed_nodes);
LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted FP32 Cast nodes "
<< ConcatNames<NodeArgToConsumerMap>(require_cast_fp32, GetName);
}
RemoveCastNodesChain(graph, {node}, removed_nodes);
LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name();
InsertCastNodes(graph, require_cast, true, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, false, logger);
LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes "
<< ConcatNames<NodeArgToConsumerMap>(require_cast, GetName);
ChangeTypeToFP16(graph, require_type_change, false, logger);
LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : "
<< ConcatNames<std::unordered_set<NodeArg*>>(require_type_change);
modified = true;
@ -585,6 +678,7 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
std::vector<Node*> casts; // Cast nodes to propagate.
std::vector<NodeArg*>& outputs = node->MutableOutputDefs();
std::unordered_set<NodeArg*> require_type_change;
NodeArgToConsumerMap non_cast_consumers_map;
// TODO Here we require the all floating point outputs are consumer by an immediate
// child cast node.
for (auto iter = outputs.begin(); iter != outputs.end() && all_float_outputs_have_casts; ++iter) {
@ -594,22 +688,28 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
}
has_float_outputs = true;
std::vector<Node*> consumers = graph.GetMutableConsumerNodes(output->Name());
for (auto node_iter = consumers.begin(); node_iter != consumers.end() && all_float_outputs_have_casts; ++node_iter) {
for (auto node_iter = consumers.begin(); node_iter != consumers.end() && (level >= 2 || all_float_outputs_have_casts); ++node_iter) {
Node* consumer = *node_iter;
if (nullptr != consumer &&
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end() &&
IsCastTo(consumer, TensorProto::FLOAT16)) {
casts.push_back(consumer);
continue;
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
if (IsCastTo(consumer, TensorProto::FLOAT16)) {
casts.push_back(consumer);
continue;
} else {
non_cast_consumers_map[output].push_back(consumer);
}
}
all_float_outputs_have_casts = false;
}
require_type_change.insert(output);
if (non_cast_consumers_map.empty()) {
require_type_change.insert(output);
}
}
if (has_float_outputs && all_float_outputs_have_casts && casts.size() > 1) {
if (has_float_outputs && (level >= 2 || all_float_outputs_have_casts) && casts.size() > 1) {
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Removed Cast nodes "
<< ConcatNames<std::vector<Node*>>(casts)
<< " feeding the same compute node " << node->Name();
<< " feeding from the same compute node " << node->Name();
InsertCastNodes(graph, non_cast_consumers_map, false, removed_nodes);
for (Node* cast : casts) {
RemoveCastNodesChain(graph, {cast}, removed_nodes);
}
@ -623,6 +723,8 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
ChangeTypeToFP16(graph, require_type_change, false, logger);
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to "
<< ConcatNames<NodeArgToConsumerMap>(node_args_map, GetName);
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted FP32 Cast node to "
<< ConcatNames<NodeArgToConsumerMap>(non_cast_consumers_map, GetName);
modified = true;
}
}
@ -705,16 +807,6 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
}
}
// Propagate FP16 Casts backward
for (auto node_index : node_topology_list) {
Node* node = graph.GetNode(node_index);
if (nullptr != node &&
std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() &&
IsCastTo(node, TensorProto::FLOAT16)) {
local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger);
}
}
// Propagate FP16 Casts from outputs to inputs
for (auto node_index : node_topology_list) {
Node* node = graph.GetNode(node_index);
@ -724,6 +816,16 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
}
}
// Propagate FP16 Casts
for (auto node_index : node_topology_list) {
Node* node = graph.GetNode(node_index);
if (nullptr != node &&
std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() &&
IsCastTo(node, TensorProto::FLOAT16)) {
local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger);
}
}
// Propagate FP32 Casts from inputs to outputs
for (auto node_index : node_topology_list) {
Node* node = graph.GetNode(node_index);
@ -755,7 +857,8 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
LOGS(logger, INFO) << "Nodes Converted to FP16:";
std::for_each(converted_node_names.begin(), converted_node_names.end(), [removed_node_names, logger](std::string name) {
if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } });
if (removed_node_names.find(name) == removed_node_names.end() && inserted_node_names.find(name) == inserted_node_names.end()) {
LOGS(logger, INFO) << name; } });
}
inserted_node_names.clear();
converted_node_names.clear();

View file

@ -4182,7 +4182,30 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}};
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 2, allow_matmul},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 3, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 3, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 3, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}};
// Create a temporary directory, which will be deleted automatically, to save/load the transformed models.
TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")};

View file

@ -135,57 +135,71 @@ def do_cast_inputs(input_0, input_1, nodes):
to = input_cast_type)])
return "cast_"+input_0, "cast_"+input_1
def do_transpose_inputs(input_0, input_1, nodes):
nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"),
helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")])
return "transpose_"+input_0, "transpose_"+input_1
nodes.extend([helper.make_node("Transpose", [input_0], ["input_transpose_0"], "Transpose_0"),
helper.make_node("Transpose", [input_1], ["input_transpose_1"], "Transpose_1")])
return "input_transpose_0", "input_transpose_1"
def do_cast_product(product, nodes):
nodes.append(helper.make_node(
nodes.insert(1,helper.make_node(
"Cast",
[product],
["cast" + product],
["product_cast"],
"Cast_2",
to = TensorProto.FLOAT16))
return "cast_"+product
return "product_cast"
def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2):
nodes = [
helper.make_node(
def do_transpose_product(product, nodes):
if transpose_product:
nodes.append(helper.make_node("Transpose", [product], ["product_transpose"], "Transpose_2"))
return "product_transpose"
def do_cast_sum(sum, nodes, type):
nodes.append(helper.make_node(
"Cast",
[sum],
["cast_sum"],
"Cast_3",
to = type))
return "cast_sum"
def do_cast_input2(input_2, nodes, type):
nodes.append(helper.make_node(
"Cast",
[input_2],
["cast_"+input_2],
"Cast_4",
to = type))
return "cast_"+input_2
def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2, transpose_inputs_before_cast=False):
input_0 = "input_0"
input_1 = "input_1"
product = "product"
nodes = []
if transpose_inputs_before_cast:
if transpose_inputs:
input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes)
if cast_inputs:
input_0, input_1 = do_cast_inputs(input_0, input_1, nodes)
else:
if cast_inputs:
input_0, input_1 = do_cast_inputs(input_0, input_1, nodes)
if transpose_inputs:
input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes)
nodes.append(helper.make_node(
"MatMul",
["input_transpose_0" if transpose_inputs else ("cast_input_0" if cast_inputs else "input_0"),
"input_transpose_1" if transpose_inputs else ("cast_input_1" if cast_inputs else "input_1")],
["product"],
[input_0,
input_1],
[product],
"MatMul_0")
]
)
if transpose_product:
product = do_transpose_product(product, nodes)
if cast_product:
nodes.append(helper.make_node(
"Cast",
["product_transpose" if transpose_product else "product"],
["product_cast"],
"Cast_2",
to = TensorProto.FLOAT16))
product = do_cast_product(product, nodes)
if cast_inputs:
input_cast_type = TensorProto.FLOAT
nodes.extend([helper.make_node(
"Cast",
["input_0"],
["cast_input_0"],
"Cast_0",
to = TensorProto.FLOAT),
helper.make_node(
"Cast",
["input_1"],
["cast_input_1"],
"Cast_1",
to = TensorProto.FLOAT)])
if transpose_inputs:
nodes.extend([helper.make_node("Transpose", ["cast_input_0" if cast_inputs else "input_0"], ["input_transpose_0"], "Transpose_0"),
helper.make_node("Transpose", ["cast_input_1" if cast_inputs else "input_1"], ["input_transpose_1"], "Transpose_1")])
if transpose_product:
nodes.append(helper.make_node("Transpose", ["product"], ["product_transpose"], "Transpose_2"))
output = product
input_type = TensorProto.FLOAT16 if cast_inputs else TensorProto.FLOAT
output_type = flip_type(cast_sum, flip_type(cast_product, flip_type(cast_inputs, input_type)))
@ -196,28 +210,21 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc
"input_1", input_type, ['N', 'N'])
]
if insert_add:
input_2 = "input_2"
add_input_type = flip_type(True, input_type) if cast_inputs != cast_product else input_type
add_input_type = flip_type(cast_input2, add_input_type)
inputs.append(helper.make_tensor_value_info("input_2", add_input_type, ['N', 'N']))
nodes.append(helper.make_node("Add", ["product_cast" if cast_product else ("product_transpose" if transpose_product else "product"), "cast_input_2" if cast_input2 else "input_2"], ["sum"], "Add_0"))
if cast_sum:
input2_cast_type = flip_type(True, flip_type(cast_input2, add_input_type))
nodes.append(helper.make_node(
"Cast",
["sum"],
["cast_sum"],
"Cast_3",
to = input2_cast_type))
inputs.append(helper.make_tensor_value_info(input_2, add_input_type, ['N', 'N']))
add_output = "sum"
if cast_input2:
nodes.append(helper.make_node(
"Cast",
["input_2"],
["cast_input_2"],
"Cast_4",
to = flip_type(True, add_input_type)))
input_2 = do_cast_input2(input_2, nodes, flip_type(True, add_input_type))
nodes.append(helper.make_node("Add", [product, input_2], [add_output], "Add_0"))
if cast_sum:
add_output = do_cast_sum(add_output, nodes, flip_type(not cast_input2, add_input_type))
output = add_output
outputs = [
helper.make_tensor_value_info(
"cast_sum" if cast_sum else "sum" if insert_add else ("product_cast" if cast_product else ("product_transpose" if transpose_product else "product")), output_type, ['N', 'N'])
output, output_type, ['N', 'N'])
]
save(model_path + ("_transpose_inputs" if transpose_inputs else "") +
@ -229,11 +236,12 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc
nodes, inputs, outputs, [])
def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul):
def do_transpose(output_0, output_1, nodes):
nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"),
helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")])
def do_transpose(output_0, output_1, transpose, nodes):
nodes.append(helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"))
output_0 = "transpose_0_"+output_0
output_1 ="transpose_1_"+output_1
if transpose > 1:
nodes.append(helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1"))
output_1 ="transpose_1_"+output_1
return output_0, output_1
input_type = TensorProto.FLOAT
input_0 = "input_0"
@ -262,9 +270,16 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
"MatMul_1"))
outputs.append(helper.make_tensor_value_info(
"second_"+output, input_type, ['M', 'N']))
if transpose and transpose_before_cast:
output_0, output_1 = do_transpose(output_0, output_1, nodes)
if add_products:
nodes.append(helper.make_node(
"Add",
[output, "second_"+output],
["sum"],
"Add_0"))
outputs.append(helper.make_tensor_value_info(
"sum", input_type, ['M', 'N']))
if transpose > 0 and transpose_before_cast:
output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes)
nodes.append(helper.make_node(
"Cast",
@ -283,8 +298,8 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
to = TensorProto.FLOAT16))
output_1 = "cast_1_"+output_1
if transpose and not transpose_before_cast:
output_0, output_1 = do_transpose(output_0, output_1, nodes)
if transpose > 0 and not transpose_before_cast:
output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes)
outputs.extend([
helper.make_tensor_value_info(
@ -292,8 +307,10 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
helper.make_tensor_value_info(
output_1, flip_type(second_matmul, input_type), ['M', 'N'])
])
model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else ""
model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose > 0 else ""
model_path += "_transpose" if transpose > 1 else ""
model_path += "_second_matmul" if second_matmul else ""
model_path += "_add_products" if add_products else ""
save(model_path, nodes, inputs, outputs, [])
for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)):
@ -305,7 +322,9 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add,
gen_fuse_sibling_casts("fuse_sibling_casts")
gen_fuse_back2back_casts("fuse_back2back_casts")
for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)):
for (transpose, transpose_before_cast, second_matmul, add_products) in list(itertools.product([0,1,2], [False, True], [False, True], [False, True])):
if not transpose and transpose_before_cast:
continue
if not second_matmul and add_products:
continue
gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul)