diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index c3af6a8aae..afed76eea6 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -11,10 +11,21 @@ using namespace onnxruntime::common; namespace onnxruntime { // NodeArg to Select consumer node map. typedef std::unordered_map> NodeArgToConsumerMap; -static std::string GetName(const std::pair>& p) { - return p.first->Name(); -}; +// ConcatNames +// Collects all the names from the pointers of the objects stores in the container class C +// the class should have a member functions returning a string (or a ref). +template +static std::string ConcatNames( + C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) { + std::vector names; + std::transform(items.begin(), items.end(), back_inserter(names), f); + return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; }); +} + +static std::string GetName(const std::pair>& p) { + return p.first->Name() + " feeding " + ConcatNames(p.second) + "; "; +}; // The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that // the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0. // The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code @@ -65,7 +76,7 @@ static Status InsertCastNodes(Graph& graph, // data_type is the data type of the Cast output. TensorProto_DataType data_type = is_fp16 ? TensorProto_DataType_FLOAT16 : TensorProto_DataType_FLOAT; TypeProto type_proto; - bool is_node_arg_cast_output = IsType(*node_arg, data_type); + bool is_node_arg_cast_output = IsType(*node_arg, data_type); // true if the producer node_arg is being replaced TensorProto_DataType new_node_arg_data_type = data_type; if (is_node_arg_cast_output) { @@ -94,31 +105,59 @@ static Status InsertCastNodes(Graph& graph, inserted_node_names.insert(cast.Name()); Node* producer = graph.GetMutableProducerNode(node_arg->Name()); std::vector consumers = graph.GetMutableConsumerNodes(node_arg->Name()); - int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *node_arg) : -1; - // Update consumers of node_arg to use the output of the cast node - int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output); - for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) { - if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() && - std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { - auto& consumer_inputs = consumer->MutableInputDefs(); - int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); - if (nullptr != producer) { + std::vector other_nodes; + if (nullptr != producer) { + int output_index = optimizer_utils::IndexOfNodeOutput(*producer, *node_arg); + for (Node* consumer : consumers) { + // Removed the edges to the consumers, getting casted. + if (std::find(nodes.begin(), nodes.end(), consumer) != nodes.end()) { + int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); graph.RemoveEdge(producer->Index(), consumer->Index(), output_index, input_index); + } else { + other_nodes.push_back(consumer); } - std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output); - graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index); + } + if (is_node_arg_cast_output) { + // Replace the node_arg with the new_node_arg in the producer outputs + auto& producer_outputs = producer->MutableOutputDefs(); + std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input); + graph.UpdateProducerNode(cast_input.Name(), producer->Index()); } } - if (nullptr != producer) { - auto& producer_outputs = producer->MutableOutputDefs(); - // The following replacement is necessary in case where the output of the cast node is original - // output of the producer, for example the original output of the producer may be the graph output. - std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input); - graph.UpdateProducerNode(cast_input.Name(), producer->Index()); - int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input); - graph.AddEdge(producer->Index(), cast.Index(), output_index, input_index); + // Update consumers of node_arg to use the output of the cast node + int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output); + for (Node* consumer : consumers) { + if (std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { + if (std::find(nodes.begin(), nodes.end(), consumer) == nodes.end()) { + // Consumers not getting casted need to replace input-def if the producer's output-def is changed + if (is_node_arg_cast_output) { + auto& consumer_inputs = consumer->MutableInputDefs(); + std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_output, &cast_input); + } + } else { + // Consumers getting casted need to get new edges from the new cast node.. + int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); + if (!is_node_arg_cast_output) { + auto& consumer_inputs = consumer->MutableInputDefs(); + std::replace(consumer_inputs.begin(), consumer_inputs.end(), &cast_input, &cast_output); + } + graph.AddEdge(cast.Index(), consumer->Index(), cast_output_index, input_index); + } + } } + // Complete the input/output connections to the new cast node, and update the graph. + other_nodes.push_back(&cast); graph.UpdateProducerNode(cast_output.Name(), cast.Index()); + graph.UpdateConsumerNodes(cast_output.Name(), nodes); + graph.UpdateConsumerNodes(cast_input.Name(), other_nodes); + if (nullptr != producer) { + int cast_input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input); + int output_index = optimizer_utils::IndexOfNodeOutput(*producer, cast_input); + graph.AddEdge(producer->Index(), cast.Index(), output_index, cast_input_index); + if (is_node_arg_cast_output) { + graph.UpdateProducerNode(cast_input.Name(), producer->Index()); + } + } } return Status::OK(); } @@ -138,7 +177,7 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: if (producer) { if (graph.IsOutput(cast_output)) { // cast_output is a graph output. Replace the cast node with an Identity operator unless node - // has other outputs. + // has no other outputs. if (producer->GetOutputEdgesCount() == 1) { int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input); graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index); @@ -146,11 +185,11 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: std::replace(outputs.begin(), outputs.end(), cast_input, cast_output); graph.UpdateProducerNode(cast_output->Name(), producer->Index()); } else { - (void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"), - "Identity", - "Created as a place-holder for a graph output", - {cast_input}, - {cast_output}); + (void)graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"), + "Identity", + "Created as a place-holder for a graph output", + {cast_input}, + {cast_output}); } } } @@ -183,19 +222,21 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: // and the second cast is from FLOAT to FLOAT16. // Condition: The parent cast should have only one output // The inputs is Cast to FLOAT16 -static bool RemoveBackToBackCasts(Graph& graph, Node* node, +static bool RemoveBackToBackCasts(Graph& graph, Node* parent, std::deque& removed_nodes, const logging::Logger& logger) { - ORT_ENFORCE(IsCastTo(node, TensorProto::FLOAT)); + ORT_ENFORCE(IsCastTo(parent, TensorProto::FLOAT)); bool modified = false; - if (graph_utils::CanRemoveNode(graph, *node, logger)) { - NodeArg* cast_output = node->MutableOutputDefs()[0]; - for (Node* child : graph.GetMutableConsumerNodes(cast_output->Name())) { + if (graph_utils::CanRemoveNode(graph, *parent, logger)) { + NodeArg* cast_output = parent->MutableOutputDefs()[0]; + std::vector children = graph.GetMutableConsumerNodes(cast_output->Name()); + if (children.size() == 1) { + Node* child = children[0]; if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) { if (IsCastTo(child, TensorProto::FLOAT16)) { // The parent and child cancell out - LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << node->Name() << " and " << child->Name(); - RemoveCastNodesChain(graph, {node, child}, removed_nodes); + LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast nodes " << parent->Name() << " and " << child->Name(); + RemoveCastNodesChain(graph, {parent, child}, removed_nodes); modified = true; } else if (IsCastTo(child, TensorProto::FLOAT)) { // Child is a duplicate of parent @@ -204,6 +245,45 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node, modified = true; } } + } else { + NodeArg* parent_input = parent->MutableInputDefs()[0]; + const Node* producer = graph.GetProducerNode(parent_input->Name()); + int producer_output_index = producer ? optimizer_utils::IndexOfNodeOutput(*producer, *parent_input) : -1; + std::vector new_consumers; + for (Node* child : children) { + if (std::find(removed_nodes.begin(), removed_nodes.end(), child->Index()) == removed_nodes.end()) { + if (IsCastTo(child, TensorProto::FLOAT16)) { + // The parent and child cancell out + // Remove the child node without effecting the other nodes. + // move all the consumers to the producer. + LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name(); + NodeArg* child_output = child->MutableOutputDefs()[0]; + for (Node* consumer : graph.GetMutableConsumerNodes(child_output->Name())) { + std::vector& consumer_inputs = consumer->MutableInputDefs(); + int output_index = optimizer_utils::IndexOfNodeOutput(*child, *child_output); + int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *child_output); + graph.RemoveEdge(child->Index(), consumer->Index(), output_index, input_index); + std::replace(consumer_inputs.begin(), consumer_inputs.end(), child_output, parent_input); + if (nullptr != producer) { + graph.AddEdge(producer->Index(), consumer->Index(), producer_output_index, input_index); + } + new_consumers.push_back(consumer); + } + modified = true; + removed_nodes.push_front(child->Index()); + } else if (IsCastTo(child, TensorProto::FLOAT)) { + // Child is a duplicate of parent + LOGS(logger, VERBOSE) << "RemoveBackToBackCasts: Removed Cast node " << child->Name(); + RemoveCastNodesChain(graph, {child}, removed_nodes); + modified = true; + } + } + } + if (new_consumers.size() > 0) { + std::vector consumers = graph.GetMutableConsumerNodes(parent_input->Name()); + std::copy(new_consumers.begin(), new_consumers.end(), back_inserter(consumers)); + graph.UpdateConsumerNodes(parent_input->Name(), consumers); + } } } return modified; @@ -216,11 +296,18 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node, // nodearg is traversed not more than once. static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node, NodeArgToConsumerMap& require_cast, + NodeArgToConsumerMap& require_cast_fp32, std::unordered_set& require_type_change, std::deque& removed_nodes, size_t level) { Node* node = graph.GetMutableProducerNode(node_arg->Name()); - if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + // If the Cast input feeds more than one node or the cast node feeds a graph output and at least one + // node then it cannot propagate. + size_t consumer_node_count = graph.GetConsumerNodes(node_arg->Name()).size(); + if (level < 2 && (consumer_node_count > 1 || + nullptr != node && + consumer_node_count > 0 && + graph.IsOutput(node_arg))) { require_cast[node_arg].push_back(dst_node); } else if (node == nullptr) { // The graph inputs don't have the producer nodes @@ -244,18 +331,30 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node, // TODO: If the specified optimization is greater than 1 then insert a Cast to the // other output_def and still propagate FP16 cast up the graph. if (output_def != node_arg) { - if (IsType(*output_def, TensorProto_DataType_FLOAT)) { + if (IsType(*output_def, TensorProto_DataType_FLOAT) && graph.GetConsumerNodes(output_def->Name()).size() > 0) { require_cast[node_arg].push_back(dst_node); return; } } } + if (level >= 2) { + for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) { + if (nullptr != consumer && consumer != dst_node && consumer->OpType() != "Cast" && + std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { + require_cast_fp32[node_arg].push_back(consumer); + } + } + if (graph.IsOutput(node_arg)) { + require_cast_fp32[node_arg] = std::vector(); + } + } for (NodeArg* node_input : node->MutableInputDefs()) { if (IsType(*node_input, TensorProto_DataType_FLOAT) && require_cast.find(node_input) == require_cast.end() && require_type_change.find(node_input) == require_type_change.end()) { - SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level); - if (require_cast.find(node_input) == require_cast.end()) { + SearchUpstream(graph, node_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level); + if (require_cast.find(node_input) == require_cast.end() && + require_cast_fp32.find(node_input) == require_cast_fp32.end()) { require_type_change.insert(node_input); } } @@ -318,17 +417,6 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg, } } -// ConcatNames -// Collects all the names from the pointers of the objects stores in the container class C -// the class should have a member functions returning a string (or a ref). -template -static std::string ConcatNames( - C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) { - std::vector names; - std::transform(items.begin(), items.end(), back_inserter(names), f); - return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; }); -} - // Change the elem_type of the given NodeArgs from FLOAT to FLOAT16. static void ChangeTypeToFP16(Graph& graph, std::unordered_set& require_type_change, bool is_forward, const logging::Logger& logger) { ONNX_NAMESPACE::TypeProto type_proto; @@ -401,25 +489,30 @@ static bool PropagateBackwards(Graph& graph, Node* node, bool modified = false; ORT_ENFORCE(nullptr != node); NodeArgToConsumerMap require_cast; + NodeArgToConsumerMap require_cast_fp32; NodeArg* cast_input = node->MutableInputDefs()[0]; - const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs - // If the Cast input feeds more than one node or the cast node feeds a graph output and at least one - // node then it cannot propagate. - size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size(); - if (consumer_node_count > 1 || - (nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) { - return modified; + std::unordered_set require_type_change; + if (graph.IsOutput(cast_input)) { + return false; } - std::unordered_set require_type_change = {cast_input}; - SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level); - if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) { + SearchUpstream(graph, cast_input, node, require_cast, require_cast_fp32, require_type_change, removed_nodes, level); + if (require_cast_fp32.empty()) { + require_type_change.insert(cast_input); + } + // TODO need a huristic when to insert FP32 Cast + if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end() /* && require_cast.size() >= require_cast_fp32.size() */) { // Remove Cast operation - LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name(); + if (require_cast_fp32.size() > 0) { + InsertCastNodes(graph, require_cast_fp32, false, removed_nodes); + LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted FP32 Cast nodes " + << ConcatNames(require_cast_fp32, GetName); + } RemoveCastNodesChain(graph, {node}, removed_nodes); + LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name(); InsertCastNodes(graph, require_cast, true, removed_nodes); - ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes " << ConcatNames(require_cast, GetName); + ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : " << ConcatNames>(require_type_change); modified = true; @@ -585,6 +678,7 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, std::vector casts; // Cast nodes to propagate. std::vector& outputs = node->MutableOutputDefs(); std::unordered_set require_type_change; + NodeArgToConsumerMap non_cast_consumers_map; // TODO Here we require the all floating point outputs are consumer by an immediate // child cast node. for (auto iter = outputs.begin(); iter != outputs.end() && all_float_outputs_have_casts; ++iter) { @@ -594,22 +688,28 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, } has_float_outputs = true; std::vector consumers = graph.GetMutableConsumerNodes(output->Name()); - for (auto node_iter = consumers.begin(); node_iter != consumers.end() && all_float_outputs_have_casts; ++node_iter) { + for (auto node_iter = consumers.begin(); node_iter != consumers.end() && (level >= 2 || all_float_outputs_have_casts); ++node_iter) { Node* consumer = *node_iter; if (nullptr != consumer && - std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end() && - IsCastTo(consumer, TensorProto::FLOAT16)) { - casts.push_back(consumer); - continue; + std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { + if (IsCastTo(consumer, TensorProto::FLOAT16)) { + casts.push_back(consumer); + continue; + } else { + non_cast_consumers_map[output].push_back(consumer); + } } all_float_outputs_have_casts = false; } - require_type_change.insert(output); + if (non_cast_consumers_map.empty()) { + require_type_change.insert(output); + } } - if (has_float_outputs && all_float_outputs_have_casts && casts.size() > 1) { + if (has_float_outputs && (level >= 2 || all_float_outputs_have_casts) && casts.size() > 1) { LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Removed Cast nodes " << ConcatNames>(casts) - << " feeding the same compute node " << node->Name(); + << " feeding from the same compute node " << node->Name(); + InsertCastNodes(graph, non_cast_consumers_map, false, removed_nodes); for (Node* cast : casts) { RemoveCastNodesChain(graph, {cast}, removed_nodes); } @@ -623,6 +723,8 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to " << ConcatNames(node_args_map, GetName); + LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted FP32 Cast node to " + << ConcatNames(non_cast_consumers_map, GetName); modified = true; } } @@ -705,16 +807,6 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level } } - // Propagate FP16 Casts backward - for (auto node_index : node_topology_list) { - Node* node = graph.GetNode(node_index); - if (nullptr != node && - std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() && - IsCastTo(node, TensorProto::FLOAT16)) { - local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger); - } - } - // Propagate FP16 Casts from outputs to inputs for (auto node_index : node_topology_list) { Node* node = graph.GetNode(node_index); @@ -724,6 +816,16 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level } } + // Propagate FP16 Casts + for (auto node_index : node_topology_list) { + Node* node = graph.GetNode(node_index); + if (nullptr != node && + std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end() && + IsCastTo(node, TensorProto::FLOAT16)) { + local_modified |= PropagateBackwards(graph, node, removed_nodes, level_, logger); + } + } + // Propagate FP32 Casts from inputs to outputs for (auto node_index : node_topology_list) { Node* node = graph.GetNode(node_index); @@ -755,7 +857,8 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level LOGS(logger, INFO) << "Nodes Converted to FP16:"; std::for_each(converted_node_names.begin(), converted_node_names.end(), [removed_node_names, logger](std::string name) { - if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } }); + if (removed_node_names.find(name) == removed_node_names.end() && inserted_node_names.find(name) == inserted_node_names.end()) { + LOGS(logger, INFO) << name; } }); } inserted_node_names.clear(); converted_node_names.clear(); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index aab654b0b7..4832428c35 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4182,7 +4182,30 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose}, {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul}, {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose}, - {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}}; + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 2, allow_matmul}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 3, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx", 3, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx", 3, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx", 2, allow_matmul_transpose, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx", 3, allow_matmul_transpose_add, 2}}; // Create a temporary directory, which will be deleted automatically, to save/load the transformed models. TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")}; diff --git a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py index 41ef15af4e..1a44da9e82 100644 --- a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py +++ b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py @@ -135,57 +135,71 @@ def do_cast_inputs(input_0, input_1, nodes): to = input_cast_type)]) return "cast_"+input_0, "cast_"+input_1 def do_transpose_inputs(input_0, input_1, nodes): - nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"), - helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")]) - return "transpose_"+input_0, "transpose_"+input_1 + nodes.extend([helper.make_node("Transpose", [input_0], ["input_transpose_0"], "Transpose_0"), + helper.make_node("Transpose", [input_1], ["input_transpose_1"], "Transpose_1")]) + return "input_transpose_0", "input_transpose_1" + def do_cast_product(product, nodes): - nodes.append(helper.make_node( + nodes.insert(1,helper.make_node( "Cast", [product], - ["cast" + product], + ["product_cast"], "Cast_2", to = TensorProto.FLOAT16)) - return "cast_"+product + return "product_cast" -def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2): - nodes = [ - helper.make_node( +def do_transpose_product(product, nodes): + if transpose_product: + nodes.append(helper.make_node("Transpose", [product], ["product_transpose"], "Transpose_2")) + return "product_transpose" + +def do_cast_sum(sum, nodes, type): + nodes.append(helper.make_node( + "Cast", + [sum], + ["cast_sum"], + "Cast_3", + to = type)) + return "cast_sum" + +def do_cast_input2(input_2, nodes, type): + nodes.append(helper.make_node( + "Cast", + [input_2], + ["cast_"+input_2], + "Cast_4", + to = type)) + return "cast_"+input_2 + +def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2, transpose_inputs_before_cast=False): + input_0 = "input_0" + input_1 = "input_1" + product = "product" + nodes = [] + if transpose_inputs_before_cast: + if transpose_inputs: + input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes) + if cast_inputs: + input_0, input_1 = do_cast_inputs(input_0, input_1, nodes) + else: + if cast_inputs: + input_0, input_1 = do_cast_inputs(input_0, input_1, nodes) + if transpose_inputs: + input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes) + nodes.append(helper.make_node( "MatMul", - ["input_transpose_0" if transpose_inputs else ("cast_input_0" if cast_inputs else "input_0"), - "input_transpose_1" if transpose_inputs else ("cast_input_1" if cast_inputs else "input_1")], - ["product"], + [input_0, + input_1], + [product], "MatMul_0") - ] + ) + if transpose_product: + product = do_transpose_product(product, nodes) if cast_product: - nodes.append(helper.make_node( - "Cast", - ["product_transpose" if transpose_product else "product"], - ["product_cast"], - "Cast_2", - to = TensorProto.FLOAT16)) + product = do_cast_product(product, nodes) - if cast_inputs: - input_cast_type = TensorProto.FLOAT - nodes.extend([helper.make_node( - "Cast", - ["input_0"], - ["cast_input_0"], - "Cast_0", - to = TensorProto.FLOAT), - helper.make_node( - "Cast", - ["input_1"], - ["cast_input_1"], - "Cast_1", - to = TensorProto.FLOAT)]) - - if transpose_inputs: - nodes.extend([helper.make_node("Transpose", ["cast_input_0" if cast_inputs else "input_0"], ["input_transpose_0"], "Transpose_0"), - helper.make_node("Transpose", ["cast_input_1" if cast_inputs else "input_1"], ["input_transpose_1"], "Transpose_1")]) - - if transpose_product: - nodes.append(helper.make_node("Transpose", ["product"], ["product_transpose"], "Transpose_2")) + output = product input_type = TensorProto.FLOAT16 if cast_inputs else TensorProto.FLOAT output_type = flip_type(cast_sum, flip_type(cast_product, flip_type(cast_inputs, input_type))) @@ -196,28 +210,21 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc "input_1", input_type, ['N', 'N']) ] if insert_add: + + input_2 = "input_2" add_input_type = flip_type(True, input_type) if cast_inputs != cast_product else input_type add_input_type = flip_type(cast_input2, add_input_type) - inputs.append(helper.make_tensor_value_info("input_2", add_input_type, ['N', 'N'])) - nodes.append(helper.make_node("Add", ["product_cast" if cast_product else ("product_transpose" if transpose_product else "product"), "cast_input_2" if cast_input2 else "input_2"], ["sum"], "Add_0")) - if cast_sum: - input2_cast_type = flip_type(True, flip_type(cast_input2, add_input_type)) - nodes.append(helper.make_node( - "Cast", - ["sum"], - ["cast_sum"], - "Cast_3", - to = input2_cast_type)) + inputs.append(helper.make_tensor_value_info(input_2, add_input_type, ['N', 'N'])) + add_output = "sum" if cast_input2: - nodes.append(helper.make_node( - "Cast", - ["input_2"], - ["cast_input_2"], - "Cast_4", - to = flip_type(True, add_input_type))) + input_2 = do_cast_input2(input_2, nodes, flip_type(True, add_input_type)) + nodes.append(helper.make_node("Add", [product, input_2], [add_output], "Add_0")) + if cast_sum: + add_output = do_cast_sum(add_output, nodes, flip_type(not cast_input2, add_input_type)) + output = add_output outputs = [ helper.make_tensor_value_info( - "cast_sum" if cast_sum else "sum" if insert_add else ("product_cast" if cast_product else ("product_transpose" if transpose_product else "product")), output_type, ['N', 'N']) + output, output_type, ['N', 'N']) ] save(model_path + ("_transpose_inputs" if transpose_inputs else "") + @@ -229,11 +236,12 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc nodes, inputs, outputs, []) def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul): - def do_transpose(output_0, output_1, nodes): - nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"), - helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")]) + def do_transpose(output_0, output_1, transpose, nodes): + nodes.append(helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0")) output_0 = "transpose_0_"+output_0 - output_1 ="transpose_1_"+output_1 + if transpose > 1: + nodes.append(helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")) + output_1 ="transpose_1_"+output_1 return output_0, output_1 input_type = TensorProto.FLOAT input_0 = "input_0" @@ -262,9 +270,16 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second "MatMul_1")) outputs.append(helper.make_tensor_value_info( "second_"+output, input_type, ['M', 'N'])) - - if transpose and transpose_before_cast: - output_0, output_1 = do_transpose(output_0, output_1, nodes) + if add_products: + nodes.append(helper.make_node( + "Add", + [output, "second_"+output], + ["sum"], + "Add_0")) + outputs.append(helper.make_tensor_value_info( + "sum", input_type, ['M', 'N'])) + if transpose > 0 and transpose_before_cast: + output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes) nodes.append(helper.make_node( "Cast", @@ -283,8 +298,8 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second to = TensorProto.FLOAT16)) output_1 = "cast_1_"+output_1 - if transpose and not transpose_before_cast: - output_0, output_1 = do_transpose(output_0, output_1, nodes) + if transpose > 0 and not transpose_before_cast: + output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes) outputs.extend([ helper.make_tensor_value_info( @@ -292,8 +307,10 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second helper.make_tensor_value_info( output_1, flip_type(second_matmul, input_type), ['M', 'N']) ]) - model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else "" + model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose > 0 else "" + model_path += "_transpose" if transpose > 1 else "" model_path += "_second_matmul" if second_matmul else "" + model_path += "_add_products" if add_products else "" save(model_path, nodes, inputs, outputs, []) for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)): @@ -305,7 +322,9 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, gen_fuse_sibling_casts("fuse_sibling_casts") gen_fuse_back2back_casts("fuse_back2back_casts") -for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)): +for (transpose, transpose_before_cast, second_matmul, add_products) in list(itertools.product([0,1,2], [False, True], [False, True], [False, True])): if not transpose and transpose_before_cast: continue + if not second_matmul and add_products: + continue gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul) \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_input2_cast_sum.onnx index 9ce67f5f39..106aec6549 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs.onnx index b27d191789..ea1d434a27 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2.onnx index 203a93e9b8..8206173b97 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2_cast_sum.onnx index a28fc9f81f..f8cb8779c4 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product.onnx index a68951f94c..abcd7aabe3 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2.onnx index 8ccabe605f..001c4c84bb 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2_cast_sum.onnx index 03ad6d536f..ed4cbeceb6 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_sum.onnx index eb82eb9214..cae48b5e33 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_sum.onnx index 8652dea69d..aa53bfd132 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_inputs_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2.onnx index fc8d71a610..5002c8b5c2 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2_cast_sum.onnx index d4f232ceec..33b54f3f46 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_input2_cast_sum.onnx index 5bd616b5dc..25be3739cc 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs.onnx index c535eccccf..ff54f2230a 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2.onnx index a2818549e4..65c153db7f 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2_cast_sum.onnx index c482513f6a..5fd2cf1a6d 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product.onnx index 161e508735..cd6311c679 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2.onnx index cb06eb74f2..857a528478 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2_cast_sum.onnx index 2afd49dd0d..2a75be7bdb 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_sum.onnx index e3f2de633b..1abc0a1344 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_sum.onnx index 11d333f7f7..8a3d1f35cc 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_inputs_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product.onnx index 8d47b24d7e..66c8ad9452 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2.onnx index 1174c138c9..82fa7a8f50 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2_cast_sum.onnx index fbe2aadba5..edfbea8052 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_sum.onnx index 704bea31dc..2e308ec1dc 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_sum.onnx index 66aeb03cc3..16ac40b03d 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_input2_cast_sum.onnx index 5c0439bf1f..78d681cdfb 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs.onnx index 4fecbb3061..e9ac197963 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2.onnx index 3ffefa026a..965e835a4e 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2_cast_sum.onnx index fdbcba1efc..d88570d1ac 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx index fa56fa95ca..37fb80e858 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2.onnx index 58ef180193..5f312efea8 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx index 2e4809525b..e7eab92e67 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_sum.onnx index 0e1ea8a882..a10d146441 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_sum.onnx index b2da72a7f5..356b4d96ca 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_inputs_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product.onnx index 648350372b..4ecc7f1d10 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2.onnx index 8f176d7b26..ee3cd46188 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2_cast_sum.onnx index fe5891ce1b..809bbd7013 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_sum.onnx index df847546d6..280017d93c 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_sum.onnx index 2aacbde9ca..e38777792e 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_inputs_transpose_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_input2_cast_sum.onnx index 1d717ffdcc..1979008fc1 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs.onnx index 1645c1a173..1a2c650a87 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2.onnx index 814425a946..43b4db3ed6 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2_cast_sum.onnx index 251f8fb0fa..8793db8e81 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product.onnx index 2ecc683a30..aa46358eb8 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx index 38efdd2dc5..9f0a5a2798 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx index 66c5732087..98ed5f225d 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_sum.onnx index 5a0d319c61..3e4a41260d 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_sum.onnx index cda903017d..f59354f0b8 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_inputs_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx index c99b20d358..8038e6f447 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx index bf3e91dca9..49695f114a 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs.onnx index b45fca5345..17e8e52759 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs_cast_product.onnx index cdce5480fa..fff0773085 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs.onnx index 71b5514525..8abd460127 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs_cast_product.onnx index f32d51d26e..f64a669d4b 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_product.onnx index 87d687b778..91c23995eb 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs.onnx index 8a6a12e419..ee821d286a 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx index 499cfa1f1d..874fec6ec6 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_product.onnx index 8af2623d1a..26137f8776 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_inputs_transpose_product_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs.onnx index 0fcc32d52b..9648bcc45e 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs_cast_product.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs_cast_product.onnx index e83e6d130c..4b84e3da3a 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs_cast_product.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_transpose_product_cast_inputs_cast_product.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx new file mode 100644 index 0000000000..baafe9ad8e Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_add_products.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx new file mode 100644 index 0000000000..aa80d71b6a Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul_sum.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx index 4cc5ae341f..2f8ffcedec 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx index cbeb4ad7f6..dceef36545 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx new file mode 100644 index 0000000000..640f179c0f Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul_add_products.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx new file mode 100644 index 0000000000..4cc5ae341f Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx new file mode 100644 index 0000000000..cbeb4ad7f6 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx new file mode 100644 index 0000000000..db0fa9e956 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_transpose_second_matmul_add_products.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx index a65e3a4e44..72219fbdde 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx index a92233aba4..6b0bd7053d 100644 Binary files a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx new file mode 100644 index 0000000000..3089ecbe34 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul_add_products.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx new file mode 100644 index 0000000000..a65e3a4e44 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx new file mode 100644 index 0000000000..a92233aba4 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx new file mode 100644 index 0000000000..1c4d092d89 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_transpose_second_matmul_add_products.onnx differ