From bb1e417da0abb79e711ae49737b8fc8642b5b753 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Mon, 19 Apr 2021 12:14:30 -0700 Subject: [PATCH] Add logging support to Cast Propagation transformation from python (#7353) * Fixes needed to PropagateCast transformation. * Added number of passes to the logs. * Added logging support to OrtModuleGraphBuilder. * Added new testcases. * Added NodeArgToConsumerMap --- .../core/optimizer/propagate_cast_ops.cc | 143 +++++++++++------- .../test/optimizer/graph_transform_test.cc | 8 +- .../propagate_cast/gen_propagate_cast.py | 100 ++++++++++++ .../propagate_cast/matmul_two_outputs.onnx | Bin 0 -> 289 bytes .../matmul_two_outputs_second_matmul.onnx | Bin 0 -> 438 bytes ...tmul_two_outputs_transpose_after_cast.onnx | Bin 0 -> 439 bytes ...ts_transpose_after_cast_second_matmul.onnx | Bin 0 -> 602 bytes ...mul_two_outputs_transpose_before_cast.onnx | Bin 0 -> 449 bytes ...s_transpose_before_cast_second_matmul.onnx | Bin 0 -> 622 bytes .../core/framework/ortmodule_graph_builder.cc | 4 +- .../core/framework/ortmodule_graph_builder.h | 5 +- .../python/orttraining_pybind_state.cc | 13 +- .../_ortmodule_graph_execution_manager.py | 5 + 13 files changed, 221 insertions(+), 57 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index 01df60e48e..c3af6a8aae 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -9,6 +9,12 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { +// NodeArg to Select consumer node map. +typedef std::unordered_map> NodeArgToConsumerMap; +static std::string GetName(const std::pair>& p) { + return p.first->Name(); +}; + // The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that // the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0. // The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code @@ -42,14 +48,17 @@ static bool IsType(const NodeArg& node_arg, TensorProto_DataType data_type) { } // InsertCastNodes -// Insert a new Cast node after each NodeArg in the require_cast vector. The cast node is FLOAT16 if is_fp16 is True +// Insert a new Cast node after each NodeArg in the require_cast map, feeding the nodes in the vector mapped to +// the NodeArg. The other consumers of the NodeArg will not be changed. The cast node is FLOAT16 if is_fp16 is True // and FLOAT otherwise. This funtion fixes the graph edges in addition to inserting the cast nodes. static Status InsertCastNodes(Graph& graph, - const std::unordered_set& require_cast, + const NodeArgToConsumerMap& require_cast, bool is_fp16, std::deque& removed_nodes) { //Create requirred new Cast nodes. - for (NodeArg* node_arg : require_cast) { + for (std::pair> element : require_cast) { + NodeArg* node_arg = element.first; + std::vector nodes = element.second; if (!node_arg->Exists()) { continue; } @@ -89,7 +98,7 @@ static Status InsertCastNodes(Graph& graph, // Update consumers of node_arg to use the output of the cast node int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output); for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) { - if (nullptr != consumer && + if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() && std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) { auto& consumer_inputs = consumer->MutableInputDefs(); int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg); @@ -102,6 +111,8 @@ static Status InsertCastNodes(Graph& graph, } if (nullptr != producer) { auto& producer_outputs = producer->MutableOutputDefs(); + // The following replacement is necessary in case where the output of the cast node is original + // output of the producer, for example the original output of the producer may be the graph output. std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input); graph.UpdateProducerNode(cast_input.Name(), producer->Index()); int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input); @@ -125,12 +136,22 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector casts, std:: auto consumers = graph.GetMutableConsumerNodes(cast_output->Name()); int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *cast_input) : -1; if (producer) { - int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input); - graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index); - if (consumers.empty()) { - auto& outputs = producer->MutableOutputDefs(); - std::replace(outputs.begin(), outputs.end(), cast_input, cast_output); - graph.UpdateProducerNode(cast_output->Name(), producer->Index()); + if (graph.IsOutput(cast_output)) { + // cast_output is a graph output. Replace the cast node with an Identity operator unless node + // has other outputs. + if (producer->GetOutputEdgesCount() == 1) { + int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input); + graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index); + auto& outputs = producer->MutableOutputDefs(); + std::replace(outputs.begin(), outputs.end(), cast_input, cast_output); + graph.UpdateProducerNode(cast_output->Name(), producer->Index()); + } else { + (void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"), + "Identity", + "Created as a place-holder for a graph output", + {cast_input}, + {cast_output}); + } } } // Update consumer nodes @@ -193,27 +214,29 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node, // inorder to move an FP16 Cast operation up the graph. // Visited float NodeArgs are either in require_cast or require_type_change so that the same // nodearg is traversed not more than once. -static void SearchUpstream(Graph& graph, NodeArg* node_arg, - std::unordered_set& require_cast, +static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node, + NodeArgToConsumerMap& require_cast, std::unordered_set& require_type_change, std::deque& removed_nodes, size_t level) { Node* node = graph.GetMutableProducerNode(node_arg->Name()); - if (node == nullptr) { + if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + require_cast[node_arg].push_back(dst_node); + } else if (node == nullptr) { // The graph inputs don't have the producer nodes if (IsType(*node_arg, TensorProto_DataType_FLOAT)) { - require_cast.insert(node_arg); + require_cast[node_arg].push_back(dst_node); } } else if (std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end()) { if (IsCastTo(node, TensorProto_DataType_FLOAT)) { // This Cast node and the Cast node that will be created later will cancel out - require_cast.insert(node_arg); + require_cast[node_arg].push_back(dst_node); } else { std::string op_type = node->OpType(); if (!IsFP16Allow(op_type, level)) { // Cannot traverse-up beyond this point if (node_arg->Exists() && IsType(*node_arg, TensorProto_DataType_FLOAT)) { - require_cast.insert(node_arg); + require_cast[node_arg].push_back(dst_node); } } else { // If the node has other float32 output(s) then stop the search. @@ -222,7 +245,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, // other output_def and still propagate FP16 cast up the graph. if (output_def != node_arg) { if (IsType(*output_def, TensorProto_DataType_FLOAT)) { - require_cast.insert(node_arg); + require_cast[node_arg].push_back(dst_node); return; } } @@ -231,7 +254,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, if (IsType(*node_input, TensorProto_DataType_FLOAT) && require_cast.find(node_input) == require_cast.end() && require_type_change.find(node_input) == require_type_change.end()) { - SearchUpstream(graph, node_input, require_cast, require_type_change, removed_nodes, level); + SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level); if (require_cast.find(node_input) == require_cast.end()) { require_type_change.insert(node_input); } @@ -248,7 +271,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg, // be converted from float to float16 along the way. // The recursion only traverses an static void SearchDownstream(Graph& graph, NodeArg* node_arg, - std::unordered_set& require_cast, + NodeArgToConsumerMap& require_cast, std::unordered_set& require_type_change, std::deque& removed_nodes, size_t level) { @@ -257,21 +280,21 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg, std::string op_type = node->OpType(); if (IsCastTo(node, TensorProto_DataType_FLOAT)) { // This Cast node and the Cast node that will be created later will cancel out - require_cast.insert(node_arg); + require_cast[node_arg].push_back(node); } else { if (!IsFP16Allow(op_type, level)) { if (node_arg->Exists() && IsType(*node_arg, TensorProto_DataType_FLOAT)) { - require_cast.insert(node_arg); + require_cast[node_arg].push_back(node); } } else { // If the node has other float32 inputs then stop the search for (const auto* input_def : node->InputDefs()) { - // TODO: If the secified level of the optimization is greater than 1 then + // TODO: If the specified level of the optimization is greater than 1 then // convert initializers if any from float to float16. if (input_def != node_arg) { if (IsType(*input_def, TensorProto_DataType_FLOAT)) { - require_cast.insert(node_arg); + require_cast[node_arg].push_back(node); return; } } @@ -290,8 +313,8 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg, } } } - if (graph.IsOutput(node_arg)) { - require_cast.insert(node_arg); + if (graph.IsOutput(node_arg) && require_cast.find(node_arg) == require_cast.end()) { + require_cast.insert(std::make_pair(node_arg, std::vector())); } } @@ -299,25 +322,31 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg, // Collects all the names from the pointers of the objects stores in the container class C // the class should have a member functions returning a string (or a ref). template -static std::string ConcatNames(C const& items) { +static std::string ConcatNames( + C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) { std::vector names; - std::transform(items.begin(), items.end(), back_inserter(names), [](T n) { return n->Name(); }); + std::transform(items.begin(), items.end(), back_inserter(names), f); return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; }); } // Change the elem_type of the given NodeArgs from FLOAT to FLOAT16. -static void ChangeTypeToFP16(Graph& graph, std::unordered_set& require_type_change, const logging::Logger& logger) { +static void ChangeTypeToFP16(Graph& graph, std::unordered_set& require_type_change, bool is_forward, const logging::Logger& logger) { ONNX_NAMESPACE::TypeProto type_proto; type_proto.mutable_tensor_type()->set_elem_type(TensorProto::FLOAT16); for (NodeArg* node_arg : require_type_change) { if (IsType(*node_arg, TensorProto::FLOAT)) { node_arg->UpdateTypeAndShape(type_proto, true, true, logger); - for (const Node* node : graph.GetConsumerNodes(node_arg->Name())) { - converted_node_names.insert(node->Name()); - } - const Node* producer = graph.GetProducerNode(node_arg->Name()); - if (nullptr != producer) { - converted_node_names.insert(producer->Name()); + if (is_forward) { + // Propagating forwards. Count consumers. + for (const Node* node : graph.GetConsumerNodes(node_arg->Name())) { + converted_node_names.insert(node->Name()); + } + } else { + // Propagating backwards. Count producers. + const Node* producer = graph.GetProducerNode(node_arg->Name()); + if (nullptr != producer) { + converted_node_names.insert(producer->Name()); + } } } } @@ -338,7 +367,7 @@ static bool PropagateForwards(Graph& graph, Node* node, const logging::Logger& logger) { ORT_ENFORCE(nullptr != node); bool modified = false; - std::unordered_set require_cast; + NodeArgToConsumerMap require_cast; std::unordered_set require_type_change; NodeArg* cast_output = node->MutableOutputDefs()[0]; SearchDownstream(graph, cast_output, require_cast, require_type_change, removed_nodes, level); @@ -347,8 +376,9 @@ static bool PropagateForwards(Graph& graph, Node* node, LOGS(logger, VERBOSE) << "PropagateForwards: Removed Cast node " << node->Name(); RemoveCastNodesChain(graph, {node}, removed_nodes); InsertCastNodes(graph, require_cast, false, removed_nodes); - ChangeTypeToFP16(graph, require_type_change, logger); - LOGS(logger, VERBOSE) << "PropagateForwwards: Inserted Cast nodes " << ConcatNames>(require_cast); + ChangeTypeToFP16(graph, require_type_change, true, logger); + LOGS(logger, VERBOSE) << "PropagateForwwards: Inserted Cast nodes " + << ConcatNames(require_cast, GetName); modified = true; } return modified; @@ -370,18 +400,26 @@ static bool PropagateBackwards(Graph& graph, Node* node, const logging::Logger& logger) { bool modified = false; ORT_ENFORCE(nullptr != node); - std::unordered_set require_cast; + NodeArgToConsumerMap require_cast; NodeArg* cast_input = node->MutableInputDefs()[0]; + const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs + // If the Cast input feeds more than one node or the cast node feeds a graph output and at least one + // node then it cannot propagate. + size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size(); + if (consumer_node_count > 1 || + (nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) { + return modified; + } std::unordered_set require_type_change = {cast_input}; - SearchUpstream(graph, cast_input, require_cast, require_type_change, removed_nodes, level); + SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level); if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) { // Remove Cast operation LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name(); RemoveCastNodesChain(graph, {node}, removed_nodes); InsertCastNodes(graph, require_cast, true, removed_nodes); - ChangeTypeToFP16(graph, require_type_change, logger); + ChangeTypeToFP16(graph, require_type_change, false, logger); LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes " - << ConcatNames>(require_cast); + << ConcatNames(require_cast, GetName); LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : " << ConcatNames>(require_type_change); modified = true; @@ -515,16 +553,17 @@ static bool PropagateFP32CastsFromInputsToOutputs(Graph& graph, Node* node, for (Node* cast : casts) { RemoveCastNodesChain(graph, {cast}, removed_nodes); } - std::unordered_set node_args; + NodeArgToConsumerMap node_args_map; for (NodeArg* output : node->MutableOutputDefs()) { if (output->Exists() && IsType(*output, TensorProto::FLOAT)) { - node_args.insert(output); + node_args_map.insert(std::make_pair(output, graph.GetMutableConsumerNodes(output->Name()))); } } - InsertCastNodes(graph, node_args, false, removed_nodes); - ChangeTypeToFP16(graph, require_type_change, logger); + InsertCastNodes(graph, node_args_map, false, removed_nodes); + ChangeTypeToFP16(graph, require_type_change, true, logger); + LOGS(logger, VERBOSE) << "PropagateFP32CastsFromInputsToOutputs: Inserted Cast node to " - << ConcatNames(node_args); + << ConcatNames(node_args_map, GetName); modified = true; } } @@ -574,15 +613,16 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, for (Node* cast : casts) { RemoveCastNodesChain(graph, {cast}, removed_nodes); } - std::unordered_set node_args; + NodeArgToConsumerMap node_args_map; for (NodeArg* input : node->MutableInputDefs()) { if (IsType(*input, TensorProto::FLOAT)) { - node_args.insert(input); + node_args_map.insert(std::make_pair(input, std::vector({node}))); } } - InsertCastNodes(graph, node_args, true, removed_nodes); - ChangeTypeToFP16(graph, require_type_change, logger); - LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to " << ConcatNames(node_args); + InsertCastNodes(graph, node_args_map, true, removed_nodes); + ChangeTypeToFP16(graph, require_type_change, false, logger); + LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to " + << ConcatNames(node_args_map, GetName); modified = true; } } @@ -704,6 +744,7 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level // Generate summary if the graph is modified if (modified) { LOGS(logger, INFO) << "Propagate Cast operations summary:"; + LOGS(logger, INFO) << "Number of passes = " << pass; LOGS(logger, INFO) << "Nodes Inserted:"; std::for_each(inserted_node_names.begin(), inserted_node_names.end(), [removed_node_names, logger](std::string name) { if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } }); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 9fe75c4e29..5a98c9d2db 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -3968,7 +3968,13 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { {MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx", 2, allow_matmul_transpose_add}, {MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx", 1, allow_matmul_transpose_add}, {MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx", 4, allow_matmul_transpose_add}, - {MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx", 3, allow_matmul_transpose_add}}; + {MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx", 3, allow_matmul_transpose_add}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose}, + {MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}}; // Create a temporary directory, which will be deleted automatically, to save/load the transformed models. TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")}; diff --git a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py index 0f91918644..41ef15af4e 100644 --- a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py +++ b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py @@ -119,6 +119,33 @@ def gen_fuse_sibling_casts(model_path): def flip_type(flip, type): return (TensorProto.FLOAT16 if type == TensorProto.FLOAT else TensorProto.FLOAT) if flip else type +def do_cast_inputs(input_0, input_1, nodes): + input_cast_type = TensorProto.FLOAT + nodes.extend([helper.make_node( + "Cast", + [input_0], + ["cast_"+input_0], + "Cast_0", + to = input_cast_type), + helper.make_node( + "Cast", + [input_1], + ["cast_"+input_1], + "Cast_1", + to = input_cast_type)]) + return "cast_"+input_0, "cast_"+input_1 +def do_transpose_inputs(input_0, input_1, nodes): + nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"), + helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")]) + return "transpose_"+input_0, "transpose_"+input_1 +def do_cast_product(product, nodes): + nodes.append(helper.make_node( + "Cast", + [product], + ["cast" + product], + "Cast_2", + to = TensorProto.FLOAT16)) + return "cast_"+product def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2): nodes = [ @@ -201,6 +228,74 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc ("_cast_sum" if cast_sum else ""), nodes, inputs, outputs, []) +def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul): + def do_transpose(output_0, output_1, nodes): + nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"), + helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")]) + output_0 = "transpose_0_"+output_0 + output_1 ="transpose_1_"+output_1 + return output_0, output_1 + input_type = TensorProto.FLOAT + input_0 = "input_0" + input_1 = "input_1" + output = "product" + output_0 = "product" + output_1 = "product" + inputs = [ + helper.make_tensor_value_info( + input_0, input_type, ['M', 'K']), + helper.make_tensor_value_info( + input_1, input_type, ['K', 'N']) + ] + outputs = [] + nodes = [ + helper.make_node( + "MatMul", + [input_0, input_1], + [output], + "MatMul_0")] + if second_matmul: + nodes.append(helper.make_node( + "MatMul", + [input_0, input_1], + ["second_"+output], + "MatMul_1")) + outputs.append(helper.make_tensor_value_info( + "second_"+output, input_type, ['M', 'N'])) + + if transpose and transpose_before_cast: + output_0, output_1 = do_transpose(output_0, output_1, nodes) + + nodes.append(helper.make_node( + "Cast", + [output_0], + ["cast_0_"+output_0], + "Cast_0", + to = TensorProto.FLOAT16)) + output_0 = "cast_0_"+output_0 + + if second_matmul: + nodes.append(helper.make_node( + "Cast", + [output_1], + ["cast_1_"+output_1], + "Cast_1", + to = TensorProto.FLOAT16)) + output_1 = "cast_1_"+output_1 + + if transpose and not transpose_before_cast: + output_0, output_1 = do_transpose(output_0, output_1, nodes) + + outputs.extend([ + helper.make_tensor_value_info( + output_0, flip_type(True, input_type), ['M', 'N']), + helper.make_tensor_value_info( + output_1, flip_type(second_matmul, input_type), ['M', 'N']) + ]) + model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else "" + model_path += "_second_matmul" if second_matmul else "" + save(model_path, nodes, inputs, outputs, []) + for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)): if not insert_add and (cast_sum or cast_input2): continue @@ -209,3 +304,8 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, gen_fuse_sibling_casts("fuse_sibling_casts") gen_fuse_back2back_casts("fuse_back2back_casts") + +for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)): + if not transpose and transpose_before_cast: + continue + gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul) \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b1287d3fa815875725e8616cc2d805dd4742d1b7 GIT binary patch literal 289 zcmd;J7ZS+N%d03V%`3^wP1P+)EiSS8$jGJ3#h#g0P+Agi0Am;mu@@BOr<5j_NOAZk zmiU(D#2YBFL0DWyTrf!?zU0JWptg9JVm4 kbriAzVAK7aShyGjc$@^ec$4#U^>Q>PIt%I|L!8#0lqcnIUWzMrb#LTaTpA5J{hqAjmlZMfnAZ>4_z&AwX|O zNyEHqASA%W$H6GX#leWd^Z literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cbeb4ad7f649641703b585cee2be524e3efd3826 GIT binary patch literal 602 zcmaKoy=ucS6ou>5#LjJN98JO9JQfn7(y3#bEjV<@l*KrTS`4*hq-*Hobna8ucCaF3 z=%OFaIp5*lFd(rM;#)NW%Ze|NYkd9nVF^Jg>IU=m@Sr59l{_~LlQ1hVYrgVzI=#H$ zWpG*cNOLB|Id@akq@H0wl86<$Ei>lnTmMh}?tYrV1myyLya|jl)XC6sW(y9UB&Z6Z zYpHqL=~Ukxu39=X9~da6(5BD}#b}8|p%#VAyMf-7tc%YA`Jr9+@zE@g#1MrZ0k|bz z)*ZX6NNsV~iThIu&XR3U57?hp@XxTlAFM*h?~XP0Z+y6k?>BRJV6s|OB~wz%4~)G( DGX1Ts literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a65e3a4e44613114c89f9b508e607915ddb99445 GIT binary patch literal 449 zcmd;J7ZS+N%d03V%`3^wP1P+)EiSQ|$jqh7#h#g0P+Agi0Am;mu@@BOr<5j_NOAZk zmiU(D#2YBFL0DX7Trf!?;gX`nyyAlV;?#J9c$g0E5QLx-C!B*rpCLAVhRFK7xUg9x zB$b?440LilHYcz-gXNT1K(rPo7gI^T1lIyaCLuwPKtNG`L1KDhNooi%0HUN}o;DB? z;Ns(86yoAy7Gm@T5#9(%Lu5&BkfdLd9v2=rfYoy$8RnOyj_eV*zmW|F+v?}U!o?uK Y<0Qz%o1CAkmz$YflwX{mRwBR%0D3@zD*ylh literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx b/onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a92233aba4f2cc860f98ba17a1487bf38348b7bb GIT binary patch literal 622 zcmaKpy-ve06opBNn)CuyY8Di>j7Y6w2{AC1jVcvGCl=YvC_S*7ThkzS)1l>#coIl)y9& z5WqPKigs8HB6h`9E3UU096h}2=>Y5P0{-c_^Mg}x@O%3@^)Ev>OD>l)xMH$eR3%eV I%Quq*Kcm~Wxc~qF literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index b8bdd2547b..c6b0116c90 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -58,7 +58,7 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream, } graph.SetInputs(input_args); - graph_transformer_config_ = config.graph_transformer_config; + logging::LoggingManager::SetDefaultLoggerSeverity(config_.loglevel); return Status::OK(); } @@ -154,7 +154,7 @@ Status OrtModuleGraphBuilder::OptimizeInferenceGraph(std::unordered_set propagate_cast_ops_allow; // graph dumping @@ -520,6 +520,14 @@ py::class_(m, "TrainingAgent", R"pbdoc(This is the main class use py::class_ module_graph_builder_config( m, "OrtModuleGraphBuilderConfiguration", R"pbdoc(Configuration information for module graph builder.)pbdoc"); + + py::enum_(m, "Severity", py::arithmetic(), py::module_local()) + .value("VERBOSE", logging::Severity::kVERBOSE) + .value("INFO", logging::Severity::kINFO) + .value("WARNING", logging::Severity::kWARNING) + .value("ERROR", logging::Severity::kERROR) + .value("FATAL", logging::Severity::kFATAL); + module_graph_builder_config.def(py::init()) .def_readwrite("initializer_names", &OrtModuleGraphBuilderConfiguration::initializer_names) .def_readwrite("initializer_names_to_train", &OrtModuleGraphBuilderConfiguration::initializer_names_to_train) @@ -527,7 +535,8 @@ py::class_(m, "TrainingAgent", R"pbdoc(This is the main class use .def_readwrite("use_invertible_layernorm_grad", &OrtModuleGraphBuilderConfiguration::use_invertible_layernorm_grad) .def_readwrite("build_gradient_graph", &OrtModuleGraphBuilderConfiguration::build_gradient_graph) - .def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config); + .def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config) + .def_readwrite("loglevel", &OrtModuleGraphBuilderConfiguration::loglevel); py::class_ graph_info(m, "GraphInfo", R"pbdoc(The information of split graphs for frontend.)pbdoc"); diff --git a/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py b/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py index 0623c092ad..2e02c0e32a 100644 --- a/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py @@ -273,5 +273,10 @@ class GraphExecutionManager(ABC): grad_builder_config.graph_transformer_config = C.GraphTransformerConfiguration() grad_builder_config.graph_transformer_config.propagate_cast_ops_level = self._propagate_cast_ops_level grad_builder_config.graph_transformer_config.propagate_cast_ops_allow = self._propagate_cast_ops_allow + grad_builder_config.loglevel = {_logger.LogLevel.VERBOSE : C.Severity.VERBOSE, + _logger.LogLevel.INFO : C.Severity.INFO, + _logger.LogLevel.WARNING : C.Severity.WARNING, + _logger.LogLevel.ERROR : C.Severity.ERROR, + _logger.LogLevel.FATAL : C.Severity.FATAL}.get(self._loglevel, C.Severity.WARNING) self._graph_builder = C.OrtModuleGraphBuilder() self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)