From 979d63159bd9bbc98b0a7a2fd8738112731bc28c Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Fri, 23 Apr 2021 13:25:54 -0700 Subject: [PATCH] Add level two optimizations for constant propagation transformation. (#7410) * Made the python script generating the testcases modular. * Modified RemoveBackToBackCasts function to remove cast even if the parent node has other consumers. * Modified InsertCastNodes to update the graph consistently for other functions to work. * Moved ConcatNames function to the top. * PropagateBackward/SearchUpstream and PropagateFP16CastsFromOutputsToInputs insert FP32 casts if the level >1 in order to propagate FP16 casts backwards. * Added new testcases for level two setting. --- .../core/optimizer/propagate_cast_ops.cc | 267 ++++++++++++------ .../test/optimizer/graph_transform_test.cc | 25 +- .../propagate_cast/gen_propagate_cast.py | 155 +++++----- .../matmul_add_cast_input2_cast_sum.onnx | Bin 365 -> 365 bytes .../matmul_add_cast_inputs.onnx | Bin 373 -> 373 bytes .../matmul_add_cast_inputs_cast_input2.onnx | Bin 428 -> 428 bytes ..._add_cast_inputs_cast_input2_cast_sum.onnx | Bin 475 -> 475 bytes .../matmul_add_cast_inputs_cast_product.onnx | Bin 428 -> 428 bytes ..._cast_inputs_cast_product_cast_input2.onnx | Bin 483 -> 483 bytes ...uts_cast_product_cast_input2_cast_sum.onnx | Bin 530 -> 530 bytes ...add_cast_inputs_cast_product_cast_sum.onnx | Bin 475 -> 475 bytes .../matmul_add_cast_inputs_cast_sum.onnx | Bin 420 -> 420 bytes .../matmul_add_cast_product_cast_input2.onnx | Bin 373 -> 373 bytes ...add_cast_product_cast_input2_cast_sum.onnx | Bin 420 -> 420 bytes ...transpose_inputs_cast_input2_cast_sum.onnx | Bin 493 -> 493 bytes ...tmul_add_transpose_inputs_cast_inputs.onnx | Bin 501 -> 501 bytes ...nspose_inputs_cast_inputs_cast_input2.onnx | Bin 556 -> 556 bytes ...puts_cast_inputs_cast_input2_cast_sum.onnx | Bin 603 -> 603 bytes ...spose_inputs_cast_inputs_cast_product.onnx | Bin 556 -> 556 bytes ..._cast_inputs_cast_product_cast_input2.onnx | Bin 611 -> 611 bytes ...uts_cast_product_cast_input2_cast_sum.onnx | Bin 658 -> 658 bytes ...uts_cast_inputs_cast_product_cast_sum.onnx | Bin 603 -> 603 bytes ...transpose_inputs_cast_inputs_cast_sum.onnx | Bin 548 -> 548 bytes ...mul_add_transpose_inputs_cast_product.onnx | Bin 446 -> 446 bytes ...spose_inputs_cast_product_cast_input2.onnx | Bin 501 -> 501 bytes ...uts_cast_product_cast_input2_cast_sum.onnx | Bin 548 -> 548 bytes ...ranspose_inputs_cast_product_cast_sum.onnx | Bin 493 -> 493 bytes .../matmul_add_transpose_inputs_cast_sum.onnx | Bin 438 -> 438 bytes ...ranspose_product_cast_input2_cast_sum.onnx | Bin 557 -> 557 bytes ..._inputs_transpose_product_cast_inputs.onnx | Bin 565 -> 565 bytes ...spose_product_cast_inputs_cast_input2.onnx | Bin 620 -> 620 bytes ...duct_cast_inputs_cast_input2_cast_sum.onnx | Bin 667 -> 667 bytes ...pose_product_cast_inputs_cast_product.onnx | Bin 620 -> 620 bytes ..._cast_inputs_cast_product_cast_input2.onnx | Bin 675 -> 675 bytes ...uts_cast_product_cast_input2_cast_sum.onnx | Bin 722 -> 722 bytes ...uct_cast_inputs_cast_product_cast_sum.onnx | Bin 667 -> 667 bytes ...ranspose_product_cast_inputs_cast_sum.onnx | Bin 612 -> 612 bytes ...inputs_transpose_product_cast_product.onnx | Bin 510 -> 510 bytes ...pose_product_cast_product_cast_input2.onnx | Bin 565 -> 565 bytes ...uct_cast_product_cast_input2_cast_sum.onnx | Bin 612 -> 612 bytes ...anspose_product_cast_product_cast_sum.onnx | Bin 557 -> 557 bytes ...ose_inputs_transpose_product_cast_sum.onnx | Bin 502 -> 502 bytes ...ranspose_product_cast_input2_cast_sum.onnx | Bin 429 -> 429 bytes ...mul_add_transpose_product_cast_inputs.onnx | Bin 437 -> 437 bytes ...spose_product_cast_inputs_cast_input2.onnx | Bin 492 -> 492 bytes ...duct_cast_inputs_cast_input2_cast_sum.onnx | Bin 539 -> 539 bytes ...pose_product_cast_inputs_cast_product.onnx | Bin 492 -> 492 bytes ..._cast_inputs_cast_product_cast_input2.onnx | Bin 547 -> 547 bytes ...uts_cast_product_cast_input2_cast_sum.onnx | Bin 594 -> 594 bytes ...uct_cast_inputs_cast_product_cast_sum.onnx | Bin 539 -> 539 bytes ...ranspose_product_cast_inputs_cast_sum.onnx | Bin 484 -> 484 bytes ...pose_product_cast_product_cast_input2.onnx | Bin 437 -> 437 bytes ...uct_cast_product_cast_input2_cast_sum.onnx | Bin 484 -> 484 bytes .../propagate_cast/matmul_cast_inputs.onnx | Bin 311 -> 311 bytes .../matmul_cast_inputs_cast_product.onnx | Bin 366 -> 366 bytes .../matmul_transpose_inputs_cast_inputs.onnx | Bin 439 -> 439 bytes ...spose_inputs_cast_inputs_cast_product.onnx | Bin 494 -> 494 bytes .../matmul_transpose_inputs_cast_product.onnx | Bin 384 -> 384 bytes ..._inputs_transpose_product_cast_inputs.onnx | Bin 503 -> 503 bytes ...pose_product_cast_inputs_cast_product.onnx | Bin 558 -> 558 bytes ...inputs_transpose_product_cast_product.onnx | Bin 448 -> 448 bytes .../matmul_transpose_product_cast_inputs.onnx | Bin 375 -> 375 bytes ...pose_product_cast_inputs_cast_product.onnx | Bin 430 -> 430 bytes ...wo_outputs_second_matmul_add_products.onnx | Bin 0 -> 507 bytes .../matmul_two_outputs_second_matmul_sum.onnx | Bin 0 -> 500 bytes ...tmul_two_outputs_transpose_after_cast.onnx | Bin 439 -> 371 bytes ...ts_transpose_after_cast_second_matmul.onnx | Bin 602 -> 520 bytes ...after_cast_second_matmul_add_products.onnx | Bin 0 -> 589 bytes ...utputs_transpose_after_cast_transpose.onnx | Bin 0 -> 439 bytes ...se_after_cast_transpose_second_matmul.onnx | Bin 0 -> 602 bytes ..._transpose_second_matmul_add_products.onnx | Bin 0 -> 671 bytes ...mul_two_outputs_transpose_before_cast.onnx | Bin 449 -> 381 bytes ...s_transpose_before_cast_second_matmul.onnx | Bin 622 -> 530 bytes ...efore_cast_second_matmul_add_products.onnx | Bin 0 -> 599 bytes ...tputs_transpose_before_cast_transpose.onnx | Bin 0 -> 449 bytes ...e_before_cast_transpose_second_matmul.onnx | Bin 0 -> 622 bytes ..._transpose_second_matmul_add_products.onnx | Bin 0 -> 691 bytes 77 files changed, 296 insertions(+), 151 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index c3af6a8aae..afed76eea6 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -11,10 +11,21 @@ using namespace onnxruntime::common; namespace onnxruntime { // NodeArg to Select consumer node map. typedef std::unordered_map> NodeArgToConsumerMap; -static std::string GetName(const std::pair>& p) { - return p.first->Name(); -}; +// ConcatNames +// Collects all the names from the pointers of the objects stores in the container class C +// the class should have a member functions returning a string (or a ref). +template +static std::string ConcatNames( + C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) { + std::vector names; + std::transform(items.begin(), items.end(), back_inserter(names), f); + return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; }); +} + +static std::string GetName(const std::pair>& p) { + return p.first->Name() + " feeding " + ConcatNames(p.second) + "; "; +}; // The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that // the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0. // The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code @@ -65,7 +76,7 @@ static Status InsertCastNodes(Graph& graph, // data_type is the data type of the Cast output. TensorProto_DataType data_type = is_fp16 ? TensorProto_DataType_FLOAT16 : TensorProto_DataType_FLOAT; TypeProto type_proto; - bool is_node_arg_cast_output = IsType(*node_arg, data_type); + bool is_node_arg_cast_output = IsType(*node_arg, data_type); // true if the producer node_arg is being replaced TensorProto_DataType new_node_arg_data_type = data_type; if (is_node_arg_cast_output) { @@ -94,31 +105,59 @@ static Status InsertCastNodes(Graph& graph, inserted_node_names.insert(cast.Name()); Node* producer = graph.GetMutableProducerNode(node_arg->Name()); std::vector consumers = graph.GetMutableConsumerNodes(node_arg->Name()); - int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *node_arg) : -1; - // Update consumers of node_arg to use the output of the cast node - int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output); - for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) { - if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() && - std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { - auto& consumer_inputs = consumer->MutableInputDefs(); - int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); - if (nullptr != producer) { + std::vector other_nodes; + if (nullptr != producer) { + int output_index = optimizer_utils::IndexOfNodeOutput(*producer, *node_arg); + for (Node* consumer : consumers) { + // Removed the edges to the consumers, getting casted. + if (std::find(nodes.begin(), nodes.end(), consumer) != nodes.end()) { + int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); graph.RemoveEdge(producer->Index(), consumer->Index(), output_index, input_index); + } else { + other_nodes.push_back(consumer); } - std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output); - graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index); + } + if (is_node_arg_cast_output) { + // Replace the node_arg with the new_node_arg in the producer outputs + auto& producer_outputs = producer->MutableOutputDefs(); + std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input); + graph.UpdateProducerNode(cast_input.Name(), producer->Index()); } } - if (nullptr != producer) { - auto& producer_outputs = producer->MutableOutputDefs(); - // The following replacement is necessary in case where the output of the cast node is original - // output of the producer, for example the original output of the producer may be the graph output. - std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input); - graph.UpdateProducerNode(cast_input.Name(), producer->Index()); - int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input); - graph.AddEdge(producer->Index(), cast.Index(), output_index, input_index); + // Update consumers of node_arg to use the output of the cast node + int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output); + for (Node* consumer : consumers) { + if (std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { + if (std::find(nodes.begin(), nodes.end(), consumer) == nodes.end()) { + // Consumers not getting casted need to replace input-def if the producer's output-def is changed + if (is_node_arg_cast_output) { + auto& consumer_inputs = consumer->MutableInputDefs(); + std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_output, &cast_input); + } + } else { + // Consumers getting casted need to get new edges from the new cast node.. + int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); + if (!is_node_arg_cast_output) { + auto& consumer_inputs = consumer->MutableInputDefs(); + std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output); + } + graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index); + } + } } + // Complete the input/output connections to the new cast node, and update the graph. + other_nodes.push_back(&cast); graph.UpdateProducerNode(cast_output.Name(), cast.Index()); + graph.UpdateConsumerNodes(cast_output.Name(), nodes); + graph.UpdateConsumerNodes(cast_input.Name(), other_nodes); + if (nullptr != producer) { + int cast_input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input); + int output_index = optimizer_utils::IndexOfNodeOutput(*producer, cast_input); + graph.AddEdge(producer->Index(), cast.Index(), output_index, cast_input_index); + if (is_node_arg_cast_output) { + graph.UpdateProducerNode(cast_input.Name(), producer->Index()); + } + } } return Status::OK(); } @@ -138,7 +177,7 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: if (producer) { if (graph.IsOutput(cast_output)) { // cast_output is a graph output. Replace the cast node with an Identity operator unless node - // has other outputs. + // has no other outputs. if (producer->GetOutputEdgesCount() == 1) { int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input); graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index); @@ -146,11 +185,11 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: std::replace(outputs.begin(), outputs.end(), cast_input, cast_output); graph.UpdateProducerNode(cast_output->Name(), producer->Index()); } else { - (void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"), - "Identity", - "Created as a place-holder for a graph output", - {cast_input}, - {cast_output}); + (void)graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"), + "Identity", + "Created as a place-holder for a graph output", + {cast_input}, + {cast_output}); } } } @@ -183,19 +222,21 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: // and the second cast is from FLOAT to FLOAT16. // Condition: The parent cast should have only one output // The inputs is Cast to FLOAT16 -static bool RemoveBackToBackCasts(Graph& graph, Node* node, +static bool RemoveBackToBackCasts(Graph& graph, Node* parent, std::deque& removed_nodes, const logging::Logger& logger) { - ORT_ENFORCE(IsCastTo(node, TensorProto::FLOAT)); + ORT_ENFORCE(IsCastTo(parent, TensorProto::FLOAT)); bool modified = false; - if (graph_utils::CanRemoveNode(graph, *node, logger)) { - NodeArg* cast_output = node->MutableOutputDefs()[0]; - for (Node* child : graph.GetMutableConsumerNodes(cast_output->Name())) { + if (graph_utils::CanRemoveNode(graph, *parent, logger)) { + NodeArg* cast_output = parent->MutableOutputDefs()[0]; + std::vector children = graph.GetMutableConsumerNodes(cast_output->Name()); + if (children.size() == 1) { + Node* child = children[0]; if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) { if (IsCastTo(child, TensorProto::FLOAT16)) { // The parent and child cancell out - LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << node->Name() << " and " << child->Name(); - RemoveCastNodesChain(graph, {node, child}, removed_nodes); + LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << parent->Name() << " and " << child->Name(); + RemoveCastNodesChain(graph, {parent, child}, removed_nodes); modified = true; } else if (IsCastTo(child, TensorProto::FLOAT)) { // Child is a duplicate of parent @@ -204,6 +245,45 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node, modified = true; } } + } else { + NodeArg* parent_input = parent->MutableInputDefs()[0]; + const Node* producer = graph.GetProducerNode(parent_input->Name()); + int producer_output_index = producer ? optimizer_utils::IndexOfNodeOutput(*producer, *parent_input) : -1; + std::vector new_consumers; + for (Node* child : children) { + if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) { + if (IsCastTo(child, TensorProto::FLOAT16)) { + // The parent and child cancell out + // Remove the child node without effecting the other nodes. + // move all the consumers to the producer. + LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name(); + NodeArg* child_output = child->MutableOutputDefs()[0]; + for (Node* consumer : graph.GetMutableConsumerNodes(child_output->Name())) { + std::vector& consumer_inputs = consumer->MutableInputDefs(); + int output_index = optimizer_utils::IndexOfNodeOutput(*child, *child_output); + int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *child_output); + graph.RemoveEdge(child->Index(), consumer->Index(), output_index, input_index); + std::replace(consumer_inputs.begin(), consumer_inputs.end(), child_output, parent_input); + if (nullptr != producer) { + graph.AddEdge(producer->Index(), consumer->Index(), producer_output_index, input_index); + } + new_consumers.push_back(consumer); + } + modified = true; + removed_nodes.push_front(child->Index()); + } else if (IsCastTo(child, TensorProto::FLOAT)) { + // Child is a duplicate of parent + LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name(); + RemoveCastNodesChain(graph, {child}, removed_nodes); + modified = true; + } + } + } + if (new_consumers.size() > 0) { + std::vector consumers = graph.GetMutableConsumerNodes(parent_input->Name()); + std::copy(new_consumers.begin(), new_consumers.end(), back_inserter(consumers)); + graph.UpdateConsumerNodes(parent_input->Name(), consumers); + } } } return modified; @@ -216,11 +296,18 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node, // nodearg is traversed not more than once. static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node, NodeArgToConsumerMap& require_cast, + NodeArgToConsumerMap& require_cast_fp32, std::unordered_set& require_type_change, std::deque& removed_nodes, size_t level) { Node* node = graph.GetMutableProducerNode(node_arg->Name()); - if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + // If the Cast input feeds more than one node or the cast node feeds a graph output and at least one + // node then it cannot propagate. + size_t consumer_node_count = graph.GetConsumerNodes(node_arg->Name()).size(); + if (level < 2 && (consumer_node_count > 1 || + nullptr != node && + consumer_node_count > 0 && + graph.IsOutput(node_arg))) { require_cast[node_arg].push_back(dst_node); } else if (node == nullptr) { // The graph inputs don't have the producer nodes @@ -244,18 +331,30 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node, // TODO: If the specified optimization is greater than 1 then insert a Cast to the // other output_def and still propagate FP16 cast up the graph. if (output_def != node_arg) { - if (IsType(*output_def, TensorProto_DataType_FLOAT)) { + if (IsType(*output_def, TensorProto_DataType_FLOAT) && graph.GetConsumerNodes(output_def->Name()).size() > 0) { require_cast[node_arg].push_back(dst_node); return; } } } + if (level >= 2) { + for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) { + if (nullptr != consumer && consumer != dst_node && consumer->OpType() != "Cast" && + std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { + require_cast_fp32[node_arg].push_back(consumer); + } + } + if (graph.IsOutput(node_arg)) { + require_cast_fp32[node_arg] = std::vector(); + } + } for (NodeArg* node_input : node->MutableInputDefs()) { if (IsType(*node_input, TensorProto_DataType_FLOAT) && require_cast.find(node_input) == require_cast.end() && require_type_change.find(node_input) == require_type_change.end()) { - SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level); - if (require_cast.find(node_input) == require_cast.end()) { + SearchUpstream(graph, node_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level); + if (require_cast.find(node_input) == require_cast.end() && + require_cast_fp32.find(node_input) == require_cast_fp32.end()) { require_type_change.insert(node_input); } } @@ -318,17 +417,6 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg, } } -// ConcatNames -// Collects all the names from the pointers of the objects stores in the container class C -// the class should have a member functions returning a string (or a ref). -template -static std::string ConcatNames( - C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) { - std::vector names; - std::transform(items.begin(), items.end(), back_inserter(names), f); - return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; }); -} - // Change the elem_type of the given NodeArgs from FLOAT to FLOAT16. static void ChangeTypeToFP16(Graph& graph, std::unordered_set& require_type_change, bool is_forward, const logging::Logger& logger) { ONNX_NAMESPACE::TypeProto type_proto; @@ -401,25 +489,30 @@ static bool PropagateBackwards(Graph& graph, Node* node, bool modified = false; ORT_ENFORCE(nullptr != node); NodeArgToConsumerMap require_cast; + NodeArgToConsumerMap require_cast_fp32; NodeArg* cast_input = node->MutableInputDefs()[0]; - const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs - // If the Cast input feeds more than one node or the cast node feeds a graph output and at least one - // node then it cannot propagate. - size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size(); - if (consumer_node_count > 1 || - (nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) { - return modified; + std::unordered_set require_type_change; + if (graph.IsOutput(cast_input)) { + return false; } - std::unordered_set require_type_change = {cast_input}; - SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level); - if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) { + SearchUpstream(graph, cast_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level); + if (require_cast_fp32.empty()) { + require_type_change.insert(cast_input); + } + // TODO need a huristic when to insert FP32 Cast + if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end() /* && require_cast.size() >= require_cast_fp32.size() */) { // Remove Cast operation - LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name(); + if (require_cast_fp32.size() > 0) { + InsertCastNodes(graph, require_cast_fp32, false, removed_nodes); + LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted FP32 Cast nodes " + << ConcatNames(require_cast_fp32, GetName); + } RemoveCastNodesChain(graph, {node}, removed_nodes); + LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name(); InsertCastNodes(graph, require_cast, true, removed_nodes); - ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes " << ConcatNames(require_cast, GetName); + ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : " << ConcatNames>(require_type_change); modified = true; @@ -585,6 +678,7 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, std::vector casts; // Cast nodes to propagate. std::vector& outputs = node->MutableOutputDefs(); std::unordered_set require_type_change; + NodeArgToConsumerMap non_cast_consumers_map; // TODO Here we require the all floating point outputs are consumer by an immediate // child cast node. for (auto iter = outputs.begin(); iter != outputs.end() && all_float_outputs_have_casts; ++iter) { @@ -594,22 +688,28 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, } has_float_outputs = true; std::vector consumers = graph.GetMutableConsumerNodes(output->Name()); - for (auto node_iter = consumers.begin(); node_iter != consumers.end() && all_float_outputs_have_casts; ++node_iter) { + for (auto node_iter = consumers.begin(); node_iter != consumers.end() && (level >= 2 || all_float_outputs_have_casts); ++node_iter) { Node* consumer = *node_iter; if (nullptr != consumer && - std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end() && - IsCastTo(consumer, TensorProto::FLOAT16)) { - casts.push_back(consumer); - continue; + std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { + if (IsCastTo(consumer, TensorProto::FLOAT16)) { + casts.push_back(consumer); + continue; + } else { + non_cast_consumers_map[output].push_back(consumer); + } } all_float_outputs_have_casts = false; } - require_type_change.insert(output); + if (non_cast_consumers_map.empty()) { + require_type_change.insert(output); + } } - if (has_float_outputs && all_float_outputs_have_casts && casts.size() > 1) { + if (has_float_outputs && (level >= 2 || all_float_outputs_have_casts) && casts.size() > 1) { LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Removed Cast nodes " << ConcatNames>(casts) - << " feeding the same compute node " << node->Name(); + << " feeding from the same compute node " << node->Name(); + InsertCastNodes(graph, non_cast_consumers_map, false, removed_nodes); for (Node* cast : casts) { RemoveCastNodesChain(graph, {cast}, removed_nodes); } @@ -623,6 +723,8 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to " << ConcatNames(node_args_map, GetName); + LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted FP32 Cast node to " + << ConcatNames(non_cast_consumers_map, GetName); modified = true; } } @@ -705,16 +807,6 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level } } - // Propagate FP16 Casts backward - for (auto node_index : node_topology_list) { - Node* node = graph.GetNode(node_index); - if (nullptr != node && - std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() && - IsCastTo(node, TensorProto::FLOAT16)) { - local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger); - } - } - // Propagate FP16 Casts from outputs to inputs for (auto node_index : node_topology_list) { Node* node = graph.GetNode(node_index); @@ -724,6 +816,16 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level } } + // Propagate FP16 Casts + for (auto node_index : node_topology_list) { + Node* node = graph.GetNode(node_index); + if (nullptr != node && + std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() && + IsCastTo(node, TensorProto::FLOAT16)) { + local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger); + } + } + // Propagate FP32 Casts from inputs to outputs for (auto node_index : node_topology_list) { Node* node = graph.GetNode(node_index); @@ -755,7 +857,8 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level LOGS(logger, INFO) << "Nodes Converted to FP16:"; std::for_each(converted_node_names.begin(), converted_node_names.end(), [removed_node_names, logger](std::string name) { - if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } }); + if (removed_node_names.find(name) == removed_node_names.end() && inserted_node_names.find(name) == inserted_node_names.end()) { + LOGS(logger, INFO) << name; } }); } inserted_node_names.clear(); converted_node_names.clear(); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index aab654b0b7..4832428c35 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4182,7 +4182,30 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose}, {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul}, {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose}, - {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}}; + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 2, allow_matmul}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 3, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 3, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 3, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}}; // Create a temporary directory, which will be deleted automatically, to save/load the transformed models. TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")}; diff --git a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py index 41ef15af4e..1a44da9e82 100644 --- a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py +++ b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py @@ -135,57 +135,71 @@ def do_cast_inputs(input_0, input_1, nodes): to = input_cast_type)]) return "cast_"+input_0, "cast_"+input_1 def do_transpose_inputs(input_0, input_1, nodes): - nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"), - helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")]) - return "transpose_"+input_0, "transpose_"+input_1 + nodes.extend([helper.make_node("Transpose", [input_0], ["input_transpose_0"], "Transpose_0"), + helper.make_node("Transpose", [input_1], ["input_transpose_1"], "Transpose_1")]) + return "input_transpose_0", "input_transpose_1" + def do_cast_product(product, nodes): - nodes.append(helper.make_node( + nodes.insert(1,helper.make_node( "Cast", [product], - ["cast" + product], + ["product_cast"], "Cast_2", to = TensorProto.FLOAT16)) - return "cast_"+product + return "product_cast" -def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2): - nodes = [ - helper.make_node( +def do_transpose_product(product, nodes): + if transpose_product: + nodes.append(helper.make_node("Transpose", [product], ["product_transpose"], "Transpose_2")) + return "product_transpose" + +def do_cast_sum(sum, nodes, type): + nodes.append(helper.make_node( + "Cast", + [sum], + ["cast_sum"], + "Cast_3", + to = type)) + return "cast_sum" + +def do_cast_input2(input_2, nodes, type): + nodes.append(helper.make_node( + "Cast", + [input_2], + ["cast_"+input_2], + "Cast_4", + to = type)) + return "cast_"+input_2 + +def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2, transpose_inputs_before_cast=False): + input_0 = "input_0" + input_1 = "input_1" + product = "product" + nodes = [] + if transpose_inputs_before_cast: + if transpose_inputs: + input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes) + if cast_inputs: + input_0, input_1 = do_cast_inputs(input_0, input_1, nodes) + else: + if cast_inputs: + input_0, input_1 = do_cast_inputs(input_0, input_1, nodes) + if transpose_inputs: + input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes) + nodes.append(helper.make_node( "MatMul", - ["input_transpose_0" if transpose_inputs else ("cast_input_0" if cast_inputs else "input_0"), - "input_transpose_1" if transpose_inputs else ("cast_input_1" if cast_inputs else "input_1")], - ["product"], + [input_0, + input_1], + [product], "MatMul_0") - ] + ) + if transpose_product: + product = do_transpose_product(product, nodes) if cast_product: - nodes.append(helper.make_node( - "Cast", - ["product_transpose" if transpose_product else "product"], - ["product_cast"], - "Cast_2", - to = TensorProto.FLOAT16)) + product = do_cast_product(product, nodes) - if cast_inputs: - input_cast_type = TensorProto.FLOAT - nodes.extend([helper.make_node( - "Cast", - ["input_0"], - ["cast_input_0"], - "Cast_0", - to = TensorProto.FLOAT), - helper.make_node( - "Cast", - ["input_1"], - ["cast_input_1"], - "Cast_1", - to = TensorProto.FLOAT)]) - - if transpose_inputs: - nodes.extend([helper.make_node("Transpose", ["cast_input_0" if cast_inputs else "input_0"], ["input_transpose_0"], "Transpose_0"), - helper.make_node("Transpose", ["cast_input_1" if cast_inputs else "input_1"], ["input_transpose_1"], "Transpose_1")]) - - if transpose_product: - nodes.append(helper.make_node("Transpose", ["product"], ["product_transpose"], "Transpose_2")) + output = product input_type = TensorProto.FLOAT16 if cast_inputs else TensorProto.FLOAT output_type = flip_type(cast_sum, flip_type(cast_product, flip_type(cast_inputs, input_type))) @@ -196,28 +210,21 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc "input_1", input_type, ['N', 'N']) ] if insert_add: + + input_2 = "input_2" add_input_type = flip_type(True, input_type) if cast_inputs != cast_product else input_type add_input_type = flip_type(cast_input2, add_input_type) - inputs.append(helper.make_tensor_value_info("input_2", add_input_type, ['N', 'N'])) - nodes.append(helper.make_node("Add", ["product_cast" if cast_product else ("product_transpose" if transpose_product else "product"), "cast_input_2" if cast_input2 else "input_2"], ["sum"], "Add_0")) - if cast_sum: - input2_cast_type = flip_type(True, flip_type(cast_input2, add_input_type)) - nodes.append(helper.make_node( - "Cast", - ["sum"], - ["cast_sum"], - "Cast_3", - to = input2_cast_type)) + inputs.append(helper.make_tensor_value_info(input_2, add_input_type, ['N', 'N'])) + add_output = "sum" if cast_input2: - nodes.append(helper.make_node( - "Cast", - ["input_2"], - ["cast_input_2"], - "Cast_4", - to = flip_type(True, add_input_type))) + input_2 = do_cast_input2(input_2, nodes, flip_type(True, add_input_type)) + nodes.append(helper.make_node("Add", [product, input_2], [add_output], "Add_0")) + if cast_sum: + add_output = do_cast_sum(add_output, nodes, flip_type(not cast_input2, add_input_type)) + output = add_output outputs = [ helper.make_tensor_value_info( - "cast_sum" if cast_sum else "sum" if insert_add else ("product_cast" if cast_product else ("product_transpose" if transpose_product else "product")), output_type, ['N', 'N']) + output, output_type, ['N', 'N']) ] save(model_path + ("_transpose_inputs" if transpose_inputs else "") + @@ -229,11 +236,12 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc nodes, inputs, outputs, []) def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul): - def do_transpose(output_0, output_1, nodes): - nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"), - helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")]) + def do_transpose(output_0, output_1, transpose, nodes): + nodes.append(helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0")) output_0 = "transpose_0_"+output_0 - output_1 ="transpose_1_"+output_1 + if transpose > 1: + nodes.append(helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")) + output_1 ="transpose_1_"+output_1 return output_0, output_1 input_type = TensorProto.FLOAT input_0 = "input_0" @@ -262,9 +270,16 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second "MatMul_1")) outputs.append(helper.make_tensor_value_info( "second_"+output, input_type, ['M', 'N'])) - - if transpose and transpose_before_cast: - output_0, output_1 = do_transpose(output_0, output_1, nodes) + if add_products: + nodes.append(helper.make_node( + "Add", + [output, "second_"+output], + ["sum"], + "Add_0")) + outputs.append(helper.make_tensor_value_info( + "sum", input_type, ['M', 'N'])) + if transpose > 0 and transpose_before_cast: + output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes) nodes.append(helper.make_node( "Cast", @@ -283,8 +298,8 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second to = TensorProto.FLOAT16)) output_1 = "cast_1_"+output_1 - if transpose and not transpose_before_cast: - output_0, output_1 = do_transpose(output_0, output_1, nodes) + if transpose > 0 and not transpose_before_cast: + output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes) outputs.extend([ helper.make_tensor_value_info( @@ -292,8 +307,10 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second helper.make_tensor_value_info( output_1, flip_type(second_matmul, input_type), ['M', 'N']) ]) - model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else "" + model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose > 0 else "" + model_path += "_transpose" if transpose > 1 else "" model_path += "_second_matmul" if second_matmul else "" + model_path += "_add_products" if add_products else "" save(model_path, nodes, inputs, outputs, []) for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)): @@ -305,7 +322,9 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, gen_fuse_sibling_casts("fuse_sibling_casts") gen_fuse_back2back_casts("fuse_back2back_casts") -for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)): +for (transpose, transpose_before_cast, second_matmul, add_products) in list(itertools.product([0,1,2], [False, True], [False, True], [False, True])): if not transpose and transpose_before_cast: continue + if not second_matmul and add_products: + continue gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul) \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_input2_cast_sum.onnx index 9ce67f5f39401315d7441448461911b7235e5507..106aec65490ee1770ee56332ece1eb32c562446a 100644 GIT binary patch delta 15 XcmaFM^p>S0IBf?! delta 15 XcmaFM^piHgw^pKDIsVl-KRaSi}vlnA>3 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2_cast_sum.onnx index a28fc9f81fdb85f188413972ba0a11c4782ac3f4..f8cb8779c468bbd65b265e673c17196ba0fed2fe 100644 GIT binary patch delta 23 fcmcc3e4BZK#Kgdqi4sneg%~v^ZV8!e#kddvX(9-0 delta 22 ecmcc3e4BZK#Kan>iHgw^p9fFeVl>%`aUlS19SL#( diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product.onnx index a68951f94c08f00ac4d1a5a33390315f2c71a6cb..abcd7aabe340ec584de1403b1dae86289ada9d0c 100644 GIT binary patch delta 21 dcmZ3(yoPy##KdOPi6Isf_i0X)u$=h(C;(cS2?PKD delta 21 dcmZ3(yoPy##Kd)$6TPe_D(XzE*PrFi4x9}r5Oz;9t)Z5!?+LtZlVa; delta 22 ecmaFM{FZrw#Kacoi4r*ze+ExHW;EG{aUlS6Dhb^H diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs.onnx index c535eccccfca4b5ca7d367e9711db025c92c4651..ff54f2230a523587f0128a26092227711b3c3bc4 100644 GIT binary patch delta 16 Ycmey${FQlv#Kh=Z6D6D{n=&2%06r21)Bpeg delta 16 Ycmey${FQlv#Kcd|6D4j?1c0C^AzIRF3v diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2_cast_sum.onnx index c482513f6ab22d25030b0aa5f162a1a0a6e8c85c..5fd2cf1a6da9865df29eb7ddd2539933443d78a4 100644 GIT binary patch delta 24 gcmcc3a+_s>#Kh=Z6D52mPh!-VY{D2ac^BhC0DKt#KccN6BRE`PG=0BY{FfFO$asoW~LjQvd(} delta 31 ncmaFN@|b0U#AFplw~4+M6BTtP)-RnrlTmMS8l%zVfS(EVrp|}$aHUIzs delta 31 ncmbQlI*E0H#AFplw~4+M6BTtP)-RnrlQCj)8l%zVkBkcesIm)H diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_sum.onnx index e3f2de633b22083a841c8b19daf11db3ee725d01..1abc0a13445c56f190951e1fb59993656d644f2b 100644 GIT binary patch delta 22 ecmcc3a+_s>#Kd0HiE$PaPt2bv;XFB=@e2TMxe2rY delta 23 fcmcc3a+_s>#AFpl=ZRj{6BTtP)-Ro$&iDlYX0ZtR diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_sum.onnx index 11d333f7f71e93d0cd019f57ee19c5d6421c4eb3..8a3d1f35cca262f8238ce05739c1695a8ecf7ea3 100644 GIT binary patch delta 16 YcmZ3&vV>)V#Kh=Z6D6D{n=-x!05snQV*mgE delta 16 YcmZ3&vV>)V#Kcd|6D4j*&*)B+aGv<{C;(qY32OiV delta 21 dcmdnTypMT;#KdjR6aB0wD(X&bFq`=EC;(m=2}A$@ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2.onnx index 1174c138c98135a456b1bdd9d2c91c446ea33a76..82fa7a8f5019f2cc848174cdc228744b830356c9 100644 GIT binary patch delta 29 lcmey${FQlv#Kc~+iP6>*&*)B+aGR{ns5kM4{^Ur;IRLl@3%dXS delta 29 lcmey${FQlv#KdiG69X+KD(X&bFq^E+s5kM4(d0)V#Kc~+iP6>*&*)B+aGM;;s5kLP)V#KdiG69X+KD(X&bFq^E+7%}mO(d1^vg#e}?3s3+6 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_sum.onnx index 704bea31dc2037d790cac9b49834690168451827..2e308ec1dc47dc562e0a37a690a971df3151b7dc 100644 GIT binary patch delta 21 dcmaFM{FZrw#Kc~+iP6>*&*)B+aGv<{3jk)!3HSg2 delta 21 dcmaFM{FZrw#KdjR6aB0wD(X&bFq`=E3jk%H3DE!m diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_sum.onnx index 66aeb03cc337f781f861b756b62f28e4f497556d..16ac40b03d25d4ac0fb9a1136c340bc2f75a4174 100644 GIT binary patch delta 15 XcmdnSyp4H+#Kh>Fi4x8ekG}^1FJuO$ delta 15 XcmdnSyp4H+#Kacoi4r*zkG}^1Ffj(t diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_input2_cast_sum.onnx index 5c0439bf1f3dbfc85fa81a3e7413a91ad32b1e33..78d681cdfb2f8f17d0859010b193b3d996b28a2c 100644 GIT binary patch delta 26 icmZ3>vX*6n#Kh>Fi4txTPZ>{+W;C4mCvI{N<3a$6Fbb>y delta 26 icmZ3>vX*6n#Kac2iPD)9kDE=_XN;Zr$7pg7<3a$7l?tf< diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs.onnx index 4fecbb30614d2daf68168614c6612ff823faa940..e9ac197963eb042f3e81bca5b50141449a3c92d8 100644 GIT binary patch delta 16 YcmdnWvXy0m#Kh=Z6D6D%n=Í&fLrvLx| delta 16 YcmdnWvXy0m#Kcd|6D4jfHf3Z406Dn^J^%m! diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2.onnx index 3ffefa026a393847324bffdb0cfd13b046242164..965e835a4efa31b845eda5ab45096aa6fe6a1bbc 100644 GIT binary patch delta 28 kcmaFE@`h!C#Kh=Z6D6D{n=+bBp3i7FIg8O`@&(2@0Hmx6&;S4c delta 28 kcmaFE@`h!C#Kcd|6D4jPx# delta 35 tcmV+;0Nnqh1)~L!7?UmmLy<~Fks&FOhNzQq0XCD_0WOoK0Wy;Z0-W413^D)! diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx index 2e4809525b46e23d7856698992ae404224703581..e7eab92e67e04f779e7c65936d84c19cb5698e5e 100644 GIT binary patch delta 33 rcmV++0N(%71=0nO7?GDUky}NP(w~tSLz8g diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_sum.onnx index 0e1ea8a88246a2a719ea6448432a64cfe0a48187..a10d146441fe9d4869cc775791ba0cd8f4689e0d 100644 GIT binary patch delta 23 fcmbQuI-7NZ#Kh@F6Qf-xUYb8q!g+HcqZT6oZWsu% delta 24 gcmbQuI-7NZ#AH22=ZRjf6BRWlwyoG)$f(5#0A$t&_y7O^ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_sum.onnx index b2da72a7f5b2d7c9c891c6d4a780659e78eff054..356b4d96ca483a308e6890a8554b6a3324b37f5e 100644 GIT binary patch delta 17 ZcmaFD@`PoA#Kh=Z6D6ECn=UQE##ZqtWDUF=Da>qtWEUj0*wh^bC6d diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_sum.onnx index df847546d6462730de609bf0a35a4e7f4c64a93c..280017d93c7b53192c47e4d4a8497ee3a7884d1a 100644 GIT binary patch delta 23 fcmZ3>vX*6n#Kh@l6Qf-xUelc@;k=odQHv1(Yp@6W delta 23 fcmZ3>vX*6n#KZ&66a8E#D(X(`GTY3|sKp2XY8VIA diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_sum.onnx index 2aacbde9ca2d0543dc52e63d112613d177af49ac..e38777792e886ae8fd7cda35eb113689ef008e17 100644 GIT binary patch delta 16 Ycmeyy{Ec~n#Kh>Fi4txbPw6lM06d)riU0rr delta 16 Ycmeyy{Ec~n#Kac2iPD)HkLxo606m-rvj6}9 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_input2_cast_sum.onnx index 1d717ffdcc9eaf8c8c6a9f12fef0ef19b01af839..1979008fc145e5533b627da3a1171316ca5572f7 100644 GIT binary patch delta 16 YcmZ3>yq0-F?Zn%L6Kmoo^Dr(106TUDw*UYD delta 16 YcmZ3>yq0-F?Zox56Kjkn^Dr(106K04ivR!s diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs.onnx index 1645c1a173d27ff3fda32fb5248965897f55371c..1a2c650a871e23183cf50cb58dbce13792eb511d 100644 GIT binary patch delta 15 XcmdnWyp?%^#Kgdqi4v9zw=x0%Ew%-i delta 15 XcmdnWyp?%^#KaoQi4rLbw=x0%E{z4l diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2.onnx index 814425a94659d086aa4ef1b1c778b6a39bf68c9d..43b4db3ed6fd6fdd3d1f7f4650e47647e7172a61 100644 GIT binary patch delta 26 icmaFE{Dygg#Kgdqi4v9*x0+3sXEdDn%4Bi?;~W5w0SZ0< delta 26 icmaFE{Dygg#KaoQi4rLjx0+3sWi+1n%4l)`;~W5yqY7sL diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2_cast_sum.onnx index 251f8fb0fa6f593c2648ac83c884fe8d61766499..8793db8e810cba8eccb94db2e0908674c1607f83 100644 GIT binary patch delta 26 icmbQuGMi(^b delta 16 YcmdnWyp?&v&WW$}ChjnrEXgE6UCh;n=m#505*UI7ytkO delta 16 Ycmdnayq$T1#Kccd6UA>$HeqZA06Gc=uK)l5 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs_cast_product.onnx index f32d51d26ea507a03cf9064093ae49955e626555..f64a669d4bf10788276d0f1ec4946a44a754acc7 100644 GIT binary patch delta 22 ecmaFI{Em5o#Kd0HiE$PaPt2Pr?mRh-u>$~cG6`t_ delta 23 fcmaFI{Em5o#AFplr-`066BTtP)-Rcy#@GP>X-EjK diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_product.onnx index 87d687b778fc6596b8782444b971d9343c1faa67..91c23995eb88ec635f1852610ca4cd89df615b15 100644 GIT binary patch delta 21 dcmZo*ZeX4uF|pTdVzl+dGddH+ohSb20031=2(JJD delta 21 dcmZo*ZeX4uF>#yIL|>bUinZ;8f0v45mZvLBC4T(cumUIJg~DjlG?X|+DH6qS?uIdj>K1K+(SHZ; zjjAR0WXc3jqAo=ZwS>kJI$~l=mCv6)lcZW=R|VdTmUt|d)x(VO=&Qa5LtU3+7@(NJ zk2UBsw8PLILnhkfP>LeH#+V+f?>H??IO2onS_Hr*R;Z8NeMsBGogVJbK`&R=V1J&1 p&f11+=Nf(tT5F)S{jM>=`kL(!p8S{X1YViQ*LlWDA;krK>kmk+g1i6# literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa80d71b6ad376bd00025fc7f7cd5f08ff1754d1 GIT binary patch literal 500 zcmZutPj7-S6dz&)UmU?+H1o!8CK|i!wA)BW`yzGtQt@xgN~0$>v>)W_xq(k8gk;Pxlz1how+&VA|12O1$a kjw$(oGXY(qxkR<$cHM_3f3zOMGZXnT&sZs>IG}I+0hQN)_W%F@ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx index 4cc5ae341f7ac9bd2700324a36005fc459cde417..2f8ffcedec700481cd2c01e90bf10b20ae0d5f12 100644 GIT binary patch delta 23 fcmdna{F!NjDAQr4iQ+3K{-|S?=3<}R%4h@tXdDOb delta 65 zcmey&w4Hf^C{sK0MDdjpTxMMC1x5KOrO72i!X-tCdBp|!#i{Xz6Bmdu8cy8W$*<0Z LEF3@CkqMsZl6b@Tf};GC(&Q2$sgk0^yyAlV;#8m@s+<&e X2tsbMJmcEQa~OG<4C5!CXS4(W{}~;u diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx new file mode 100644 index 0000000000000000000000000000000000000000..640f179c0f17a9cc728f97a90d938b506bfc18f2 GIT binary patch literal 589 zcmaJ-&rgFe6qbR5J(aLtX7I*w$(Gf4^0w)1f$Xr87gC@a5@`u7iT{j=f2vSJM@r(Q z@4dY5$Lr_0DBz5JH7%26MK_WP`TXs`2HcX>txR`AgOOV|yl8V7dNGl4`d`CkW5 z18M~K1)>YYl@W7+_?JwhFKJMbVkqU4ci9MSf01cbFEQzo>PIt%I|L!8#0lqcnIUWzMrb#LTaTpA5J{hqAjmlZMfnAZ>4_z&AwX|O zNyEHqASA%W$H6GX#leWd^Z literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cbeb4ad7f649641703b585cee2be524e3efd3826 GIT binary patch literal 602 zcmaKoy=ucS6ou>5#LjJN98JO9JQfn7(y3#bEjV<@l*KrTS`4*hq-*Hobna8ucCaF3 z=%OFaIp5*lFd(rM;#)NW%Ze|NYkd9nVF^Jg>IU=m@Sr59l{_~LlQ1hVYrgVzI=#H$ zWpG*cNOLB|Id@akq@H0wl86<$Ei>lnTmMh}?tYrV1myyLya|jl)XC6sW(y9UB&Z6Z zYpHqL=~Ukxu39=X9~da6(5BD}#b}8|p%#VAyMf-7tc%YA`Jr9+@zE@g#1MrZ0k|bz z)*ZX6NNsV~iThIu&XR3U57?hp@XxTlAFM*h?~XP0Z+y6k?>BRJV6s|OB~wz%4~)G( DGX1Ts literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx new file mode 100644 index 0000000000000000000000000000000000000000..db0fa9e95676535598e0839ef162073d3411a058 GIT binary patch literal 671 zcmaKpK~IA)7=~rwV4q4@FEe=KxMa&}Jb80^TOd2^ThXSK>25q=!YXEJWT(p@w^ zpWtb9X?9PioT@$<8;8A&zPcqS6Oo8XGLmfH&zFxpg zvU~WkX44EdGI*MK0|Spj$ulY%Cd#g-(fa0e(V~SaV5F#qIt8C8N((g#)+lJ%Kj_F= zlYM5gJau35{-6ep0`NTx0ocfj`{i&Eu`V8ZG0CIJj7>n2ub`j19v;X^zJ-6LF!({c diQzZjJN54!xC(CfOSmJf-qsc2Ot23bSbq|~zK{R_ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx index a65e3a4e44613114c89f9b508e607915ddb99445..72219fbdde4b099c550ae6c307295241e2dc2570 100644 GIT binary patch delta 23 fcmX@e{FiBhDAQS{iQ*X(x1VH|=3<}R&u9byXP*b( delta 80 zcmey%bdY(1DAPpdiQ*YLW?bwAMfoYE$t6O;$kl-%1(GNTemSP5g7b?cWqwCUbD3-|mv%jzb2UEoo1(Y2@&%1g#c>?`++egEk#)B|4O z6dk!Z;q`LyDXaNQi+M&{cJB_D4-H(V%5x@>a1j`Ah`a*YUpJ zsj!d(wxuBe9cf{6?0PU5g1ZKcQg4#e*dR)8pvi71Zk?la1OF`D@?kJZ+wAX>Q_j|a VGv{)j7Y6w2{AC1jVcvGCl=YvC_S*7ThkzS)1l>#coIl)y9& z5WqPKigs8HB6h`9E3UU096h}2=>Y5P0{-c_^Mg}x@O%3@^)Ev>OD>l)xMH$eR3%eV I%Quq*Kcm~Wxc~qF literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1c4d092d897ee89e5c84a9a979f0b2d7d55dcc29 GIT binary patch literal 691 zcmaKp%};|c7{+C6gT3m)>Oq4yj!U-C#CY)LytzO;?Bs>;r6G})(2{ufAMsBeWl}1j zmwr9Z^LyL3rhzO<$Y)s*p5^I=r;I;*&0qsYMv96DZo9z9C`y`CF}F=GZ4xNSE z{E@?BIjrWMVXO?YE)aELr-8g%y_6>=O$QSZF2S zFC{I)_mHQD`nz6lW%H2*OH)GtI?_Bb>^l*w;$Db;