mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Add logging support to Cast Propagation transformation from python (#7353)
* Fixes needed to PropagateCast transformation. * Added number of passes to the logs. * Added logging support to OrtModuleGraphBuilder. * Added new testcases. * Added NodeArgToConsumerMap
This commit is contained in:
parent
6dda1e0681
commit
bb1e417da0
13 changed files with 221 additions and 57 deletions
|
|
@ -9,6 +9,12 @@
|
|||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
// NodeArg to Select consumer node map.
|
||||
typedef std::unordered_map<NodeArg*, std::vector<Node*>> NodeArgToConsumerMap;
|
||||
static std::string GetName(const std::pair<const NodeArg*, std::vector<Node*>>& p) {
|
||||
return p.first->Name();
|
||||
};
|
||||
|
||||
// The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that
|
||||
// the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0.
|
||||
// The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code
|
||||
|
|
@ -42,14 +48,17 @@ static bool IsType(const NodeArg& node_arg, TensorProto_DataType data_type) {
|
|||
}
|
||||
|
||||
// InsertCastNodes
|
||||
// Insert a new Cast node after each NodeArg in the require_cast vector. The cast node is FLOAT16 if is_fp16 is True
|
||||
// Insert a new Cast node after each NodeArg in the require_cast map, feeding the nodes in the vector mapped to
|
||||
// the NodeArg. The other consumers of the NodeArg will not be changed. The cast node is FLOAT16 if is_fp16 is True
|
||||
// and FLOAT otherwise. This funtion fixes the graph edges in addition to inserting the cast nodes.
|
||||
static Status InsertCastNodes(Graph& graph,
|
||||
const std::unordered_set<NodeArg*>& require_cast,
|
||||
const NodeArgToConsumerMap& require_cast,
|
||||
bool is_fp16,
|
||||
std::deque<NodeIndex>& removed_nodes) {
|
||||
//Create requirred new Cast nodes.
|
||||
for (NodeArg* node_arg : require_cast) {
|
||||
for (std::pair<NodeArg*, std::vector<Node*>> element : require_cast) {
|
||||
NodeArg* node_arg = element.first;
|
||||
std::vector<Node*> nodes = element.second;
|
||||
if (!node_arg->Exists()) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -89,7 +98,7 @@ static Status InsertCastNodes(Graph& graph,
|
|||
// Update consumers of node_arg to use the output of the cast node
|
||||
int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output);
|
||||
for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) {
|
||||
if (nullptr != consumer &&
|
||||
if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() &&
|
||||
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
|
||||
auto& consumer_inputs = consumer->MutableInputDefs();
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
|
||||
|
|
@ -102,6 +111,8 @@ static Status InsertCastNodes(Graph& graph,
|
|||
}
|
||||
if (nullptr != producer) {
|
||||
auto& producer_outputs = producer->MutableOutputDefs();
|
||||
// The following replacement is necessary in case where the output of the cast node is original
|
||||
// output of the producer, for example the original output of the producer may be the graph output.
|
||||
std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input);
|
||||
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input);
|
||||
|
|
@ -125,12 +136,22 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
|
|||
auto consumers = graph.GetMutableConsumerNodes(cast_output->Name());
|
||||
int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *cast_input) : -1;
|
||||
if (producer) {
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input);
|
||||
graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index);
|
||||
if (consumers.empty()) {
|
||||
auto& outputs = producer->MutableOutputDefs();
|
||||
std::replace(outputs.begin(), outputs.end(), cast_input, cast_output);
|
||||
graph.UpdateProducerNode(cast_output->Name(), producer->Index());
|
||||
if (graph.IsOutput(cast_output)) {
|
||||
// cast_output is a graph output. Replace the cast node with an Identity operator unless node
|
||||
// has other outputs.
|
||||
if (producer->GetOutputEdgesCount() == 1) {
|
||||
int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input);
|
||||
graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index);
|
||||
auto& outputs = producer->MutableOutputDefs();
|
||||
std::replace(outputs.begin(), outputs.end(), cast_input, cast_output);
|
||||
graph.UpdateProducerNode(cast_output->Name(), producer->Index());
|
||||
} else {
|
||||
(void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"),
|
||||
"Identity",
|
||||
"Created as a place-holder for a graph output",
|
||||
{cast_input},
|
||||
{cast_output});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update consumer nodes
|
||||
|
|
@ -193,27 +214,29 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node,
|
|||
// inorder to move an FP16 Cast operation up the graph.
|
||||
// Visited float NodeArgs are either in require_cast or require_type_change so that the same
|
||||
// nodearg is traversed not more than once.
|
||||
static void SearchUpstream(Graph& graph, NodeArg* node_arg,
|
||||
std::unordered_set<NodeArg*>& require_cast,
|
||||
static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node,
|
||||
NodeArgToConsumerMap& require_cast,
|
||||
std::unordered_set<NodeArg*>& require_type_change,
|
||||
std::deque<NodeIndex>& removed_nodes,
|
||||
size_t level) {
|
||||
Node* node = graph.GetMutableProducerNode(node_arg->Name());
|
||||
if (node == nullptr) {
|
||||
if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
} else if (node == nullptr) {
|
||||
// The graph inputs don't have the producer nodes
|
||||
if (IsType(*node_arg, TensorProto_DataType_FLOAT)) {
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
}
|
||||
} else if (std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end()) {
|
||||
if (IsCastTo(node, TensorProto_DataType_FLOAT)) {
|
||||
// This Cast node and the Cast node that will be created later will cancel out
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
} else {
|
||||
std::string op_type = node->OpType();
|
||||
if (!IsFP16Allow(op_type, level)) {
|
||||
// Cannot traverse-up beyond this point
|
||||
if (node_arg->Exists() && IsType(*node_arg, TensorProto_DataType_FLOAT)) {
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
}
|
||||
} else {
|
||||
// If the node has other float32 output(s) then stop the search.
|
||||
|
|
@ -222,7 +245,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg,
|
|||
// other output_def and still propagate FP16 cast up the graph.
|
||||
if (output_def != node_arg) {
|
||||
if (IsType(*output_def, TensorProto_DataType_FLOAT)) {
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(dst_node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -231,7 +254,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg,
|
|||
if (IsType(*node_input, TensorProto_DataType_FLOAT) &&
|
||||
require_cast.find(node_input) == require_cast.end() &&
|
||||
require_type_change.find(node_input) == require_type_change.end()) {
|
||||
SearchUpstream(graph, node_input, require_cast, require_type_change, removed_nodes, level);
|
||||
SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level);
|
||||
if (require_cast.find(node_input) == require_cast.end()) {
|
||||
require_type_change.insert(node_input);
|
||||
}
|
||||
|
|
@ -248,7 +271,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg,
|
|||
// be converted from float to float16 along the way.
|
||||
// The recursion only traverses an
|
||||
static void SearchDownstream(Graph& graph, NodeArg* node_arg,
|
||||
std::unordered_set<NodeArg*>& require_cast,
|
||||
NodeArgToConsumerMap& require_cast,
|
||||
std::unordered_set<NodeArg*>& require_type_change,
|
||||
std::deque<NodeIndex>& removed_nodes,
|
||||
size_t level) {
|
||||
|
|
@ -257,21 +280,21 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
|
|||
std::string op_type = node->OpType();
|
||||
if (IsCastTo(node, TensorProto_DataType_FLOAT)) {
|
||||
// This Cast node and the Cast node that will be created later will cancel out
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(node);
|
||||
} else {
|
||||
if (!IsFP16Allow(op_type, level)) {
|
||||
if (node_arg->Exists() &&
|
||||
IsType(*node_arg, TensorProto_DataType_FLOAT)) {
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(node);
|
||||
}
|
||||
} else {
|
||||
// If the node has other float32 inputs then stop the search
|
||||
for (const auto* input_def : node->InputDefs()) {
|
||||
// TODO: If the secified level of the optimization is greater than 1 then
|
||||
// TODO: If the specified level of the optimization is greater than 1 then
|
||||
// convert initializers if any from float to float16.
|
||||
if (input_def != node_arg) {
|
||||
if (IsType(*input_def, TensorProto_DataType_FLOAT)) {
|
||||
require_cast.insert(node_arg);
|
||||
require_cast[node_arg].push_back(node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -290,8 +313,8 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
|
|||
}
|
||||
}
|
||||
}
|
||||
if (graph.IsOutput(node_arg)) {
|
||||
require_cast.insert(node_arg);
|
||||
if (graph.IsOutput(node_arg) && require_cast.find(node_arg) == require_cast.end()) {
|
||||
require_cast.insert(std::make_pair(node_arg, std::vector<Node*>()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -299,25 +322,31 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
|
|||
// Collects all the names from the pointers of the objects stores in the container class C
|
||||
// the class should have a member functions returning a string (or a ref).
|
||||
template <typename C, typename T = typename C::value_type>
|
||||
static std::string ConcatNames(C const& items) {
|
||||
static std::string ConcatNames(
|
||||
C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) {
|
||||
std::vector<std::string> names;
|
||||
std::transform(items.begin(), items.end(), back_inserter(names), [](T n) { return n->Name(); });
|
||||
std::transform(items.begin(), items.end(), back_inserter(names), f);
|
||||
return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; });
|
||||
}
|
||||
|
||||
// Change the elem_type of the given NodeArgs from FLOAT to FLOAT16.
|
||||
static void ChangeTypeToFP16(Graph& graph, std::unordered_set<NodeArg*>& require_type_change, const logging::Logger& logger) {
|
||||
static void ChangeTypeToFP16(Graph& graph, std::unordered_set<NodeArg*>& require_type_change, bool is_forward, const logging::Logger& logger) {
|
||||
ONNX_NAMESPACE::TypeProto type_proto;
|
||||
type_proto.mutable_tensor_type()->set_elem_type(TensorProto::FLOAT16);
|
||||
for (NodeArg* node_arg : require_type_change) {
|
||||
if (IsType(*node_arg, TensorProto::FLOAT)) {
|
||||
node_arg->UpdateTypeAndShape(type_proto, true, true, logger);
|
||||
for (const Node* node : graph.GetConsumerNodes(node_arg->Name())) {
|
||||
converted_node_names.insert(node->Name());
|
||||
}
|
||||
const Node* producer = graph.GetProducerNode(node_arg->Name());
|
||||
if (nullptr != producer) {
|
||||
converted_node_names.insert(producer->Name());
|
||||
if (is_forward) {
|
||||
// Propagating forwards. Count consumers.
|
||||
for (const Node* node : graph.GetConsumerNodes(node_arg->Name())) {
|
||||
converted_node_names.insert(node->Name());
|
||||
}
|
||||
} else {
|
||||
// Propagating backwards. Count producers.
|
||||
const Node* producer = graph.GetProducerNode(node_arg->Name());
|
||||
if (nullptr != producer) {
|
||||
converted_node_names.insert(producer->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -338,7 +367,7 @@ static bool PropagateForwards(Graph& graph, Node* node,
|
|||
const logging::Logger& logger) {
|
||||
ORT_ENFORCE(nullptr != node);
|
||||
bool modified = false;
|
||||
std::unordered_set<NodeArg*> require_cast;
|
||||
NodeArgToConsumerMap require_cast;
|
||||
std::unordered_set<NodeArg*> require_type_change;
|
||||
NodeArg* cast_output = node->MutableOutputDefs()[0];
|
||||
SearchDownstream(graph, cast_output, require_cast, require_type_change, removed_nodes, level);
|
||||
|
|
@ -347,8 +376,9 @@ static bool PropagateForwards(Graph& graph, Node* node,
|
|||
LOGS(logger, VERBOSE) << "PropagateForwards: Removed Cast node " << node->Name();
|
||||
RemoveCastNodesChain(graph, {node}, removed_nodes);
|
||||
InsertCastNodes(graph, require_cast, false, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateForwwards: Inserted Cast nodes " << ConcatNames<std::unordered_set<NodeArg*>>(require_cast);
|
||||
ChangeTypeToFP16(graph, require_type_change, true, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateForwwards: Inserted Cast nodes "
|
||||
<< ConcatNames<NodeArgToConsumerMap>(require_cast, GetName);
|
||||
modified = true;
|
||||
}
|
||||
return modified;
|
||||
|
|
@ -370,18 +400,26 @@ static bool PropagateBackwards(Graph& graph, Node* node,
|
|||
const logging::Logger& logger) {
|
||||
bool modified = false;
|
||||
ORT_ENFORCE(nullptr != node);
|
||||
std::unordered_set<NodeArg*> require_cast;
|
||||
NodeArgToConsumerMap require_cast;
|
||||
NodeArg* cast_input = node->MutableInputDefs()[0];
|
||||
const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs
|
||||
// If the Cast input feeds more than one node or the cast node feeds a graph output and at least one
|
||||
// node then it cannot propagate.
|
||||
size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size();
|
||||
if (consumer_node_count > 1 ||
|
||||
(nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) {
|
||||
return modified;
|
||||
}
|
||||
std::unordered_set<NodeArg*> require_type_change = {cast_input};
|
||||
SearchUpstream(graph, cast_input, require_cast, require_type_change, removed_nodes, level);
|
||||
SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level);
|
||||
if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) {
|
||||
// Remove Cast operation
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name();
|
||||
RemoveCastNodesChain(graph, {node}, removed_nodes);
|
||||
InsertCastNodes(graph, require_cast, true, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, logger);
|
||||
ChangeTypeToFP16(graph, require_type_change, false, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes "
|
||||
<< ConcatNames<std::unordered_set<NodeArg*>>(require_cast);
|
||||
<< ConcatNames<NodeArgToConsumerMap>(require_cast, GetName);
|
||||
LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : "
|
||||
<< ConcatNames<std::unordered_set<NodeArg*>>(require_type_change);
|
||||
modified = true;
|
||||
|
|
@ -515,16 +553,17 @@ static bool PropagateFP32CastsFromInputsToOutputs(Graph& graph, Node* node,
|
|||
for (Node* cast : casts) {
|
||||
RemoveCastNodesChain(graph, {cast}, removed_nodes);
|
||||
}
|
||||
std::unordered_set<NodeArg*> node_args;
|
||||
NodeArgToConsumerMap node_args_map;
|
||||
for (NodeArg* output : node->MutableOutputDefs()) {
|
||||
if (output->Exists() && IsType(*output, TensorProto::FLOAT)) {
|
||||
node_args.insert(output);
|
||||
node_args_map.insert(std::make_pair(output, graph.GetMutableConsumerNodes(output->Name())));
|
||||
}
|
||||
}
|
||||
InsertCastNodes(graph, node_args, false, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, logger);
|
||||
InsertCastNodes(graph, node_args_map, false, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, true, logger);
|
||||
|
||||
LOGS(logger, VERBOSE) << "PropagateFP32CastsFromInputsToOutputs: Inserted Cast node to "
|
||||
<< ConcatNames(node_args);
|
||||
<< ConcatNames<NodeArgToConsumerMap>(node_args_map, GetName);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -574,15 +613,16 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
|
|||
for (Node* cast : casts) {
|
||||
RemoveCastNodesChain(graph, {cast}, removed_nodes);
|
||||
}
|
||||
std::unordered_set<NodeArg*> node_args;
|
||||
NodeArgToConsumerMap node_args_map;
|
||||
for (NodeArg* input : node->MutableInputDefs()) {
|
||||
if (IsType(*input, TensorProto::FLOAT)) {
|
||||
node_args.insert(input);
|
||||
node_args_map.insert(std::make_pair(input, std::vector<Node*>({node})));
|
||||
}
|
||||
}
|
||||
InsertCastNodes(graph, node_args, true, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to " << ConcatNames(node_args);
|
||||
InsertCastNodes(graph, node_args_map, true, removed_nodes);
|
||||
ChangeTypeToFP16(graph, require_type_change, false, logger);
|
||||
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to "
|
||||
<< ConcatNames<NodeArgToConsumerMap>(node_args_map, GetName);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -704,6 +744,7 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
|
|||
// Generate summary if the graph is modified
|
||||
if (modified) {
|
||||
LOGS(logger, INFO) << "Propagate Cast operations summary:";
|
||||
LOGS(logger, INFO) << "Number of passes = " << pass;
|
||||
LOGS(logger, INFO) << "Nodes Inserted:";
|
||||
std::for_each(inserted_node_names.begin(), inserted_node_names.end(), [removed_node_names, logger](std::string name) {
|
||||
if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } });
|
||||
|
|
|
|||
|
|
@ -3968,7 +3968,13 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
|
|||
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx", 2, allow_matmul_transpose_add},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx", 1, allow_matmul_transpose_add},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx", 4, allow_matmul_transpose_add},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx", 3, allow_matmul_transpose_add}};
|
||||
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx", 3, allow_matmul_transpose_add},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose},
|
||||
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}};
|
||||
|
||||
// Create a temporary directory, which will be deleted automatically, to save/load the transformed models.
|
||||
TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")};
|
||||
|
|
|
|||
|
|
@ -119,6 +119,33 @@ def gen_fuse_sibling_casts(model_path):
|
|||
|
||||
def flip_type(flip, type):
|
||||
return (TensorProto.FLOAT16 if type == TensorProto.FLOAT else TensorProto.FLOAT) if flip else type
|
||||
def do_cast_inputs(input_0, input_1, nodes):
|
||||
input_cast_type = TensorProto.FLOAT
|
||||
nodes.extend([helper.make_node(
|
||||
"Cast",
|
||||
[input_0],
|
||||
["cast_"+input_0],
|
||||
"Cast_0",
|
||||
to = input_cast_type),
|
||||
helper.make_node(
|
||||
"Cast",
|
||||
[input_1],
|
||||
["cast_"+input_1],
|
||||
"Cast_1",
|
||||
to = input_cast_type)])
|
||||
return "cast_"+input_0, "cast_"+input_1
|
||||
def do_transpose_inputs(input_0, input_1, nodes):
|
||||
nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"),
|
||||
helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")])
|
||||
return "transpose_"+input_0, "transpose_"+input_1
|
||||
def do_cast_product(product, nodes):
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
[product],
|
||||
["cast" + product],
|
||||
"Cast_2",
|
||||
to = TensorProto.FLOAT16))
|
||||
return "cast_"+product
|
||||
|
||||
def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2):
|
||||
nodes = [
|
||||
|
|
@ -201,6 +228,74 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc
|
|||
("_cast_sum" if cast_sum else ""),
|
||||
nodes, inputs, outputs, [])
|
||||
|
||||
def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul):
|
||||
def do_transpose(output_0, output_1, nodes):
|
||||
nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"),
|
||||
helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")])
|
||||
output_0 = "transpose_0_"+output_0
|
||||
output_1 ="transpose_1_"+output_1
|
||||
return output_0, output_1
|
||||
input_type = TensorProto.FLOAT
|
||||
input_0 = "input_0"
|
||||
input_1 = "input_1"
|
||||
output = "product"
|
||||
output_0 = "product"
|
||||
output_1 = "product"
|
||||
inputs = [
|
||||
helper.make_tensor_value_info(
|
||||
input_0, input_type, ['M', 'K']),
|
||||
helper.make_tensor_value_info(
|
||||
input_1, input_type, ['K', 'N'])
|
||||
]
|
||||
outputs = []
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"MatMul",
|
||||
[input_0, input_1],
|
||||
[output],
|
||||
"MatMul_0")]
|
||||
if second_matmul:
|
||||
nodes.append(helper.make_node(
|
||||
"MatMul",
|
||||
[input_0, input_1],
|
||||
["second_"+output],
|
||||
"MatMul_1"))
|
||||
outputs.append(helper.make_tensor_value_info(
|
||||
"second_"+output, input_type, ['M', 'N']))
|
||||
|
||||
if transpose and transpose_before_cast:
|
||||
output_0, output_1 = do_transpose(output_0, output_1, nodes)
|
||||
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
[output_0],
|
||||
["cast_0_"+output_0],
|
||||
"Cast_0",
|
||||
to = TensorProto.FLOAT16))
|
||||
output_0 = "cast_0_"+output_0
|
||||
|
||||
if second_matmul:
|
||||
nodes.append(helper.make_node(
|
||||
"Cast",
|
||||
[output_1],
|
||||
["cast_1_"+output_1],
|
||||
"Cast_1",
|
||||
to = TensorProto.FLOAT16))
|
||||
output_1 = "cast_1_"+output_1
|
||||
|
||||
if transpose and not transpose_before_cast:
|
||||
output_0, output_1 = do_transpose(output_0, output_1, nodes)
|
||||
|
||||
outputs.extend([
|
||||
helper.make_tensor_value_info(
|
||||
output_0, flip_type(True, input_type), ['M', 'N']),
|
||||
helper.make_tensor_value_info(
|
||||
output_1, flip_type(second_matmul, input_type), ['M', 'N'])
|
||||
])
|
||||
model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else ""
|
||||
model_path += "_second_matmul" if second_matmul else ""
|
||||
save(model_path, nodes, inputs, outputs, [])
|
||||
|
||||
for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)):
|
||||
if not insert_add and (cast_sum or cast_input2):
|
||||
continue
|
||||
|
|
@ -209,3 +304,8 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add,
|
|||
|
||||
gen_fuse_sibling_casts("fuse_sibling_casts")
|
||||
gen_fuse_back2back_casts("fuse_back2back_casts")
|
||||
|
||||
for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)):
|
||||
if not transpose and transpose_before_cast:
|
||||
continue
|
||||
gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul)
|
||||
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_second_matmul.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_after_cast.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/matmul_two_outputs_transpose_before_cast.onnx
vendored
Normal file
Binary file not shown.
Binary file not shown.
|
|
@ -58,7 +58,7 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream,
|
|||
}
|
||||
|
||||
graph.SetInputs(input_args);
|
||||
graph_transformer_config_ = config.graph_transformer_config;
|
||||
logging::LoggingManager::SetDefaultLoggerSeverity(config_.loglevel);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -154,7 +154,7 @@ Status OrtModuleGraphBuilder::OptimizeInferenceGraph(std::unordered_set<std::str
|
|||
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
|
||||
auto add_transformers = [&](TransformerLevel level) {
|
||||
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
|
||||
level, x_node_arg_names, graph_transformer_config_, *cpu_execution_provider);
|
||||
level, x_node_arg_names, config_.graph_transformer_config, *cpu_execution_provider);
|
||||
for (auto& entry : transformers_to_register) {
|
||||
graph_transformation_mgr.Register(std::move(entry), level);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,11 @@ struct OrtModuleGraphBuilderConfiguration {
|
|||
bool use_invertible_layernorm_grad = false;
|
||||
bool build_gradient_graph = true;
|
||||
|
||||
// Graph transformer configuration
|
||||
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
|
||||
|
||||
// Log severity
|
||||
logging::Severity loglevel{logging::Severity::kWARNING};
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -113,7 +117,6 @@ class OrtModuleGraphBuilder {
|
|||
|
||||
OrtModuleGraphBuilderConfiguration config_;
|
||||
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
|
||||
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config_;
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ struct TrainingParameters {
|
|||
bool enable_adasum = false;
|
||||
|
||||
// transformation
|
||||
int propagate_cast_ops_level = 1;
|
||||
int propagate_cast_ops_level = -1;
|
||||
std::vector<std::string> propagate_cast_ops_allow;
|
||||
|
||||
// graph dumping
|
||||
|
|
@ -520,6 +520,14 @@ py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class use
|
|||
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(
|
||||
m, "OrtModuleGraphBuilderConfiguration",
|
||||
R"pbdoc(Configuration information for module graph builder.)pbdoc");
|
||||
|
||||
py::enum_<Severity>(m, "Severity", py::arithmetic(), py::module_local())
|
||||
.value("VERBOSE", logging::Severity::kVERBOSE)
|
||||
.value("INFO", logging::Severity::kINFO)
|
||||
.value("WARNING", logging::Severity::kWARNING)
|
||||
.value("ERROR", logging::Severity::kERROR)
|
||||
.value("FATAL", logging::Severity::kFATAL);
|
||||
|
||||
module_graph_builder_config.def(py::init())
|
||||
.def_readwrite("initializer_names", &OrtModuleGraphBuilderConfiguration::initializer_names)
|
||||
.def_readwrite("initializer_names_to_train", &OrtModuleGraphBuilderConfiguration::initializer_names_to_train)
|
||||
|
|
@ -527,7 +535,8 @@ py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class use
|
|||
.def_readwrite("use_invertible_layernorm_grad",
|
||||
&OrtModuleGraphBuilderConfiguration::use_invertible_layernorm_grad)
|
||||
.def_readwrite("build_gradient_graph", &OrtModuleGraphBuilderConfiguration::build_gradient_graph)
|
||||
.def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config);
|
||||
.def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config)
|
||||
.def_readwrite("loglevel", &OrtModuleGraphBuilderConfiguration::loglevel);
|
||||
|
||||
py::class_<GraphInfo> graph_info(m, "GraphInfo",
|
||||
R"pbdoc(The information of split graphs for frontend.)pbdoc");
|
||||
|
|
|
|||
|
|
@ -273,5 +273,10 @@ class GraphExecutionManager(ABC):
|
|||
grad_builder_config.graph_transformer_config = C.GraphTransformerConfiguration()
|
||||
grad_builder_config.graph_transformer_config.propagate_cast_ops_level = self._propagate_cast_ops_level
|
||||
grad_builder_config.graph_transformer_config.propagate_cast_ops_allow = self._propagate_cast_ops_allow
|
||||
grad_builder_config.loglevel = {_logger.LogLevel.VERBOSE : C.Severity.VERBOSE,
|
||||
_logger.LogLevel.INFO : C.Severity.INFO,
|
||||
_logger.LogLevel.WARNING : C.Severity.WARNING,
|
||||
_logger.LogLevel.ERROR : C.Severity.ERROR,
|
||||
_logger.LogLevel.FATAL : C.Severity.FATAL}.get(self._loglevel, C.Severity.WARNING)
|
||||
self._graph_builder = C.OrtModuleGraphBuilder()
|
||||
self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)
|
||||
|
|
|
|||
Loading…
Reference in a new issue