Add logging support to Cast Propagation transformation from python (#7353)

* Fixes needed to PropagateCast transformation.

* Added number of passes to the logs.

* Added logging support to OrtModuleGraphBuilder.

* Added new testcases.

* Added NodeArgToConsumerMap
This commit is contained in:
satyajandhyala 2021-04-19 12:14:30 -07:00 committed by GitHub
parent 6dda1e0681
commit bb1e417da0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 221 additions and 57 deletions

View file

@ -9,6 +9,12 @@
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
// NodeArg to Select consumer node map.
typedef std::unordered_map<NodeArg*, std::vector<Node*>> NodeArgToConsumerMap;
static std::string GetName(const std::pair<const NodeArg*, std::vector<Node*>>& p) {
return p.first->Name();
};
// The collection fp16_allow_ops, specifies for a given propagate_cast_ops level, a vector of node op_types that
// the code is allowed to propage Cast operations cross. The user may specify a custom list of optypes using level 0.
// The opcodes are split into multiple levels. Cast propagation is done based on the level. Level 2 op code
@ -42,14 +48,17 @@ static bool IsType(const NodeArg& node_arg, TensorProto_DataType data_type) {
}
// InsertCastNodes
// Insert a new Cast node after each NodeArg in the require_cast vector. The cast node is FLOAT16 if is_fp16 is True
// Insert a new Cast node after each NodeArg in the require_cast map, feeding the nodes in the vector mapped to
// the NodeArg. The other consumers of the NodeArg will not be changed. The cast node is FLOAT16 if is_fp16 is True
// and FLOAT otherwise. This funtion fixes the graph edges in addition to inserting the cast nodes.
static Status InsertCastNodes(Graph& graph,
const std::unordered_set<NodeArg*>& require_cast,
const NodeArgToConsumerMap& require_cast,
bool is_fp16,
std::deque<NodeIndex>& removed_nodes) {
//Create requirred new Cast nodes.
for (NodeArg* node_arg : require_cast) {
for (std::pair<NodeArg*, std::vector<Node*>> element : require_cast) {
NodeArg* node_arg = element.first;
std::vector<Node*> nodes = element.second;
if (!node_arg->Exists()) {
continue;
}
@ -89,7 +98,7 @@ static Status InsertCastNodes(Graph& graph,
// Update consumers of node_arg to use the output of the cast node
int cast_output_index = optimizer_utils::IndexOfNodeOutput(cast, cast_output);
for (Node* consumer : graph.GetMutableConsumerNodes(node_arg->Name())) {
if (nullptr != consumer &&
if (nullptr != consumer && std::find(nodes.begin(), nodes.end(), consumer) != nodes.end() &&
std::find(removed_nodes.begin(), removed_nodes.end(), consumer->Index()) == removed_nodes.end()) {
auto& consumer_inputs = consumer->MutableInputDefs();
int input_index = optimizer_utils::IndexOfNodeInput(*consumer, *node_arg);
@ -102,6 +111,8 @@ static Status InsertCastNodes(Graph& graph,
}
if (nullptr != producer) {
auto& producer_outputs = producer->MutableOutputDefs();
// The following replacement is necessary in case where the output of the cast node is original
// output of the producer, for example the original output of the producer may be the graph output.
std::replace(producer_outputs.begin(), producer_outputs.end(), &cast_output, &cast_input);
graph.UpdateProducerNode(cast_input.Name(), producer->Index());
int input_index = optimizer_utils::IndexOfNodeInput(cast, cast_input);
@ -125,12 +136,22 @@ static Status RemoveCastNodesChain(Graph& graph, std::vector<Node*> casts, std::
auto consumers = graph.GetMutableConsumerNodes(cast_output->Name());
int output_index = (nullptr != producer) ? optimizer_utils::IndexOfNodeOutput(*producer, *cast_input) : -1;
if (producer) {
int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input);
graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index);
if (consumers.empty()) {
auto& outputs = producer->MutableOutputDefs();
std::replace(outputs.begin(), outputs.end(), cast_input, cast_output);
graph.UpdateProducerNode(cast_output->Name(), producer->Index());
if (graph.IsOutput(cast_output)) {
// cast_output is a graph output. Replace the cast node with an Identity operator unless node
// has other outputs.
if (producer->GetOutputEdgesCount() == 1) {
int input_index = optimizer_utils::IndexOfNodeInput(*lead_cast, *cast_input);
graph.RemoveEdge(producer->Index(), lead_cast->Index(), output_index, input_index);
auto& outputs = producer->MutableOutputDefs();
std::replace(outputs.begin(), outputs.end(), cast_input, cast_output);
graph.UpdateProducerNode(cast_output->Name(), producer->Index());
} else {
(void) graph.AddNode(graph.GenerateNodeName(producer->Name() + "_identity"),
"Identity",
"Created as a place-holder for a graph output",
{cast_input},
{cast_output});
}
}
}
// Update consumer nodes
@ -193,27 +214,29 @@ static bool RemoveBackToBackCasts(Graph& graph, Node* node,
// inorder to move an FP16 Cast operation up the graph.
// Visited float NodeArgs are either in require_cast or require_type_change so that the same
// nodearg is traversed not more than once.
static void SearchUpstream(Graph& graph, NodeArg* node_arg,
std::unordered_set<NodeArg*>& require_cast,
static void SearchUpstream(Graph& graph, NodeArg* node_arg, Node* dst_node,
NodeArgToConsumerMap& require_cast,
std::unordered_set<NodeArg*>& require_type_change,
std::deque<NodeIndex>& removed_nodes,
size_t level) {
Node* node = graph.GetMutableProducerNode(node_arg->Name());
if (node == nullptr) {
if (graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
require_cast[node_arg].push_back(dst_node);
} else if (node == nullptr) {
// The graph inputs don't have the producer nodes
if (IsType(*node_arg, TensorProto_DataType_FLOAT)) {
require_cast.insert(node_arg);
require_cast[node_arg].push_back(dst_node);
}
} else if (std::find(removed_nodes.begin(), removed_nodes.end(), node->Index()) == removed_nodes.end()) {
if (IsCastTo(node, TensorProto_DataType_FLOAT)) {
// This Cast node and the Cast node that will be created later will cancel out
require_cast.insert(node_arg);
require_cast[node_arg].push_back(dst_node);
} else {
std::string op_type = node->OpType();
if (!IsFP16Allow(op_type, level)) {
// Cannot traverse-up beyond this point
if (node_arg->Exists() && IsType(*node_arg, TensorProto_DataType_FLOAT)) {
require_cast.insert(node_arg);
require_cast[node_arg].push_back(dst_node);
}
} else {
// If the node has other float32 output(s) then stop the search.
@ -222,7 +245,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg,
// other output_def and still propagate FP16 cast up the graph.
if (output_def != node_arg) {
if (IsType(*output_def, TensorProto_DataType_FLOAT)) {
require_cast.insert(node_arg);
require_cast[node_arg].push_back(dst_node);
return;
}
}
@ -231,7 +254,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg,
if (IsType(*node_input, TensorProto_DataType_FLOAT) &&
require_cast.find(node_input) == require_cast.end() &&
require_type_change.find(node_input) == require_type_change.end()) {
SearchUpstream(graph, node_input, require_cast, require_type_change, removed_nodes, level);
SearchUpstream(graph, node_input, node, require_cast, require_type_change, removed_nodes, level);
if (require_cast.find(node_input) == require_cast.end()) {
require_type_change.insert(node_input);
}
@ -248,7 +271,7 @@ static void SearchUpstream(Graph& graph, NodeArg* node_arg,
// be converted from float to float16 along the way.
// The recursion only traverses an
static void SearchDownstream(Graph& graph, NodeArg* node_arg,
std::unordered_set<NodeArg*>& require_cast,
NodeArgToConsumerMap& require_cast,
std::unordered_set<NodeArg*>& require_type_change,
std::deque<NodeIndex>& removed_nodes,
size_t level) {
@ -257,21 +280,21 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
std::string op_type = node->OpType();
if (IsCastTo(node, TensorProto_DataType_FLOAT)) {
// This Cast node and the Cast node that will be created later will cancel out
require_cast.insert(node_arg);
require_cast[node_arg].push_back(node);
} else {
if (!IsFP16Allow(op_type, level)) {
if (node_arg->Exists() &&
IsType(*node_arg, TensorProto_DataType_FLOAT)) {
require_cast.insert(node_arg);
require_cast[node_arg].push_back(node);
}
} else {
// If the node has other float32 inputs then stop the search
for (const auto* input_def : node->InputDefs()) {
// TODO: If the secified level of the optimization is greater than 1 then
// TODO: If the specified level of the optimization is greater than 1 then
// convert initializers if any from float to float16.
if (input_def != node_arg) {
if (IsType(*input_def, TensorProto_DataType_FLOAT)) {
require_cast.insert(node_arg);
require_cast[node_arg].push_back(node);
return;
}
}
@ -290,8 +313,8 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
}
}
}
if (graph.IsOutput(node_arg)) {
require_cast.insert(node_arg);
if (graph.IsOutput(node_arg) && require_cast.find(node_arg) == require_cast.end()) {
require_cast.insert(std::make_pair(node_arg, std::vector<Node*>()));
}
}
@ -299,25 +322,31 @@ static void SearchDownstream(Graph& graph, NodeArg* node_arg,
// Collects all the names from the pointers of the objects stores in the container class C
// the class should have a member functions returning a string (or a ref).
template <typename C, typename T = typename C::value_type>
static std::string ConcatNames(C const& items) {
static std::string ConcatNames(
C const& items, std::string (*f)(const T& n) = [](const T& n) { return n->Name(); }) {
std::vector<std::string> names;
std::transform(items.begin(), items.end(), back_inserter(names), [](T n) { return n->Name(); });
std::transform(items.begin(), items.end(), back_inserter(names), f);
return std::accumulate(names.begin(), names.end(), std::string(), [](const std::string& a, const std::string& b) { return a + ", " + b; });
}
// Change the elem_type of the given NodeArgs from FLOAT to FLOAT16.
static void ChangeTypeToFP16(Graph& graph, std::unordered_set<NodeArg*>& require_type_change, const logging::Logger& logger) {
static void ChangeTypeToFP16(Graph& graph, std::unordered_set<NodeArg*>& require_type_change, bool is_forward, const logging::Logger& logger) {
ONNX_NAMESPACE::TypeProto type_proto;
type_proto.mutable_tensor_type()->set_elem_type(TensorProto::FLOAT16);
for (NodeArg* node_arg : require_type_change) {
if (IsType(*node_arg, TensorProto::FLOAT)) {
node_arg->UpdateTypeAndShape(type_proto, true, true, logger);
for (const Node* node : graph.GetConsumerNodes(node_arg->Name())) {
converted_node_names.insert(node->Name());
}
const Node* producer = graph.GetProducerNode(node_arg->Name());
if (nullptr != producer) {
converted_node_names.insert(producer->Name());
if (is_forward) {
// Propagating forwards. Count consumers.
for (const Node* node : graph.GetConsumerNodes(node_arg->Name())) {
converted_node_names.insert(node->Name());
}
} else {
// Propagating backwards. Count producers.
const Node* producer = graph.GetProducerNode(node_arg->Name());
if (nullptr != producer) {
converted_node_names.insert(producer->Name());
}
}
}
}
@ -338,7 +367,7 @@ static bool PropagateForwards(Graph& graph, Node* node,
const logging::Logger& logger) {
ORT_ENFORCE(nullptr != node);
bool modified = false;
std::unordered_set<NodeArg*> require_cast;
NodeArgToConsumerMap require_cast;
std::unordered_set<NodeArg*> require_type_change;
NodeArg* cast_output = node->MutableOutputDefs()[0];
SearchDownstream(graph, cast_output, require_cast, require_type_change, removed_nodes, level);
@ -347,8 +376,9 @@ static bool PropagateForwards(Graph& graph, Node* node,
LOGS(logger, VERBOSE) << "PropagateForwards: Removed Cast node " << node->Name();
RemoveCastNodesChain(graph, {node}, removed_nodes);
InsertCastNodes(graph, require_cast, false, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, logger);
LOGS(logger, VERBOSE) << "PropagateForwwards: Inserted Cast nodes " << ConcatNames<std::unordered_set<NodeArg*>>(require_cast);
ChangeTypeToFP16(graph, require_type_change, true, logger);
LOGS(logger, VERBOSE) << "PropagateForwwards: Inserted Cast nodes "
<< ConcatNames<NodeArgToConsumerMap>(require_cast, GetName);
modified = true;
}
return modified;
@ -370,18 +400,26 @@ static bool PropagateBackwards(Graph& graph, Node* node,
const logging::Logger& logger) {
bool modified = false;
ORT_ENFORCE(nullptr != node);
std::unordered_set<NodeArg*> require_cast;
NodeArgToConsumerMap require_cast;
NodeArg* cast_input = node->MutableInputDefs()[0];
const Node* cast_input_producer = graph.GetProducerNode(cast_input->Name()); // nullptr for graph outputs
// If the Cast input feeds more than one node or the cast node feeds a graph output and at least one
// node then it cannot propagate.
size_t consumer_node_count = graph.GetConsumerNodes(cast_input->Name()).size();
if (consumer_node_count > 1 ||
(nullptr != cast_input_producer && graph.GetNodeOutputsInGraphOutputs(*cast_input_producer).size() > 0 && consumer_node_count > 0)) {
return modified;
}
std::unordered_set<NodeArg*> require_type_change = {cast_input};
SearchUpstream(graph, cast_input, require_cast, require_type_change, removed_nodes, level);
SearchUpstream(graph, cast_input, node, require_cast, require_type_change, removed_nodes, level);
if (require_cast.size() > 0 && require_cast.find(cast_input) == require_cast.end()) {
// Remove Cast operation
LOGS(logger, VERBOSE) << "PropagateBackwards: Removed Cast node " << node->Name();
RemoveCastNodesChain(graph, {node}, removed_nodes);
InsertCastNodes(graph, require_cast, true, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, logger);
ChangeTypeToFP16(graph, require_type_change, false, logger);
LOGS(logger, VERBOSE) << "PropagateBackwards: Inserted Cast nodes "
<< ConcatNames<std::unordered_set<NodeArg*>>(require_cast);
<< ConcatNames<NodeArgToConsumerMap>(require_cast, GetName);
LOGS(logger, VERBOSE) << "PropagateBackwards: Changed the type from float to float16 : "
<< ConcatNames<std::unordered_set<NodeArg*>>(require_type_change);
modified = true;
@ -515,16 +553,17 @@ static bool PropagateFP32CastsFromInputsToOutputs(Graph& graph, Node* node,
for (Node* cast : casts) {
RemoveCastNodesChain(graph, {cast}, removed_nodes);
}
std::unordered_set<NodeArg*> node_args;
NodeArgToConsumerMap node_args_map;
for (NodeArg* output : node->MutableOutputDefs()) {
if (output->Exists() && IsType(*output, TensorProto::FLOAT)) {
node_args.insert(output);
node_args_map.insert(std::make_pair(output, graph.GetMutableConsumerNodes(output->Name())));
}
}
InsertCastNodes(graph, node_args, false, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, logger);
InsertCastNodes(graph, node_args_map, false, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, true, logger);
LOGS(logger, VERBOSE) << "PropagateFP32CastsFromInputsToOutputs: Inserted Cast node to "
<< ConcatNames(node_args);
<< ConcatNames<NodeArgToConsumerMap>(node_args_map, GetName);
modified = true;
}
}
@ -574,15 +613,16 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
for (Node* cast : casts) {
RemoveCastNodesChain(graph, {cast}, removed_nodes);
}
std::unordered_set<NodeArg*> node_args;
NodeArgToConsumerMap node_args_map;
for (NodeArg* input : node->MutableInputDefs()) {
if (IsType(*input, TensorProto::FLOAT)) {
node_args.insert(input);
node_args_map.insert(std::make_pair(input, std::vector<Node*>({node})));
}
}
InsertCastNodes(graph, node_args, true, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, logger);
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to " << ConcatNames(node_args);
InsertCastNodes(graph, node_args_map, true, removed_nodes);
ChangeTypeToFP16(graph, require_type_change, false, logger);
LOGS(logger, VERBOSE) << "PropagateFP16CastsFromOutputsToInputs: Inserted Cast node to "
<< ConcatNames<NodeArgToConsumerMap>(node_args_map, GetName);
modified = true;
}
}
@ -704,6 +744,7 @@ Status PropagateCastOps::ApplyImpl(Graph& graph, bool& modified, int graph_level
// Generate summary if the graph is modified
if (modified) {
LOGS(logger, INFO) << "Propagate Cast operations summary:";
LOGS(logger, INFO) << "Number of passes = " << pass;
LOGS(logger, INFO) << "Nodes Inserted:";
std::for_each(inserted_node_names.begin(), inserted_node_names.end(), [removed_node_names, logger](std::string name) {
if (removed_node_names.find(name) == removed_node_names.end()) { LOGS(logger, INFO) << name; } });

View file

@ -3968,7 +3968,13 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2_cast_sum.onnx", 2, allow_matmul_transpose_add},
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_inputs_cast_product_cast_input2.onnx", 1, allow_matmul_transpose_add},
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2_cast_sum.onnx", 4, allow_matmul_transpose_add},
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx", 3, allow_matmul_transpose_add}};
{MODEL_FOLDER "propagate_cast/matmul_add_transpose_product_cast_product_cast_input2.onnx", 3, allow_matmul_transpose_add},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs.onnx", 1, allow_matmul},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast.onnx", 1, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_second_matmul.onnx", 2, allow_matmul},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_after_cast_second_matmul.onnx", 2, allow_matmul_transpose},
{MODEL_FOLDER "propagate_cast/matmul_two_outputs_transpose_before_cast_second_matmul.onnx", 2, allow_matmul_transpose}};
// Create a temporary directory, which will be deleted automatically, to save/load the transformed models.
TemporaryDirectory temp_dir{ORT_TSTR("propagate_casts_test_output_dir")};

View file

@ -119,6 +119,33 @@ def gen_fuse_sibling_casts(model_path):
def flip_type(flip, type):
return (TensorProto.FLOAT16 if type == TensorProto.FLOAT else TensorProto.FLOAT) if flip else type
def do_cast_inputs(input_0, input_1, nodes):
input_cast_type = TensorProto.FLOAT
nodes.extend([helper.make_node(
"Cast",
[input_0],
["cast_"+input_0],
"Cast_0",
to = input_cast_type),
helper.make_node(
"Cast",
[input_1],
["cast_"+input_1],
"Cast_1",
to = input_cast_type)])
return "cast_"+input_0, "cast_"+input_1
def do_transpose_inputs(input_0, input_1, nodes):
nodes.extend([helper.make_node("Transpose", [input_0], ["transpose_"+input_0], "Transpose_0"),
helper.make_node("Transpose", [input_1], ["transpose_"+input_1], "Transpose_1")])
return "transpose_"+input_0, "transpose_"+input_1
def do_cast_product(product, nodes):
nodes.append(helper.make_node(
"Cast",
[product],
["cast" + product],
"Cast_2",
to = TensorProto.FLOAT16))
return "cast_"+product
def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2):
nodes = [
@ -201,6 +228,74 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc
("_cast_sum" if cast_sum else ""),
nodes, inputs, outputs, [])
def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul):
def do_transpose(output_0, output_1, nodes):
nodes.extend([helper.make_node("Transpose", [output_0], ["transpose_0_"+output_0], "Transpose_0"),
helper.make_node("Transpose", [output_1], ["transpose_1_"+output_1], "Transpose_1")])
output_0 = "transpose_0_"+output_0
output_1 ="transpose_1_"+output_1
return output_0, output_1
input_type = TensorProto.FLOAT
input_0 = "input_0"
input_1 = "input_1"
output = "product"
output_0 = "product"
output_1 = "product"
inputs = [
helper.make_tensor_value_info(
input_0, input_type, ['M', 'K']),
helper.make_tensor_value_info(
input_1, input_type, ['K', 'N'])
]
outputs = []
nodes = [
helper.make_node(
"MatMul",
[input_0, input_1],
[output],
"MatMul_0")]
if second_matmul:
nodes.append(helper.make_node(
"MatMul",
[input_0, input_1],
["second_"+output],
"MatMul_1"))
outputs.append(helper.make_tensor_value_info(
"second_"+output, input_type, ['M', 'N']))
if transpose and transpose_before_cast:
output_0, output_1 = do_transpose(output_0, output_1, nodes)
nodes.append(helper.make_node(
"Cast",
[output_0],
["cast_0_"+output_0],
"Cast_0",
to = TensorProto.FLOAT16))
output_0 = "cast_0_"+output_0
if second_matmul:
nodes.append(helper.make_node(
"Cast",
[output_1],
["cast_1_"+output_1],
"Cast_1",
to = TensorProto.FLOAT16))
output_1 = "cast_1_"+output_1
if transpose and not transpose_before_cast:
output_0, output_1 = do_transpose(output_0, output_1, nodes)
outputs.extend([
helper.make_tensor_value_info(
output_0, flip_type(True, input_type), ['M', 'N']),
helper.make_tensor_value_info(
output_1, flip_type(second_matmul, input_type), ['M', 'N'])
])
model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose else ""
model_path += "_second_matmul" if second_matmul else ""
save(model_path, nodes, inputs, outputs, [])
for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)):
if not insert_add and (cast_sum or cast_input2):
continue
@ -209,3 +304,8 @@ for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add,
gen_fuse_sibling_casts("fuse_sibling_casts")
gen_fuse_back2back_casts("fuse_back2back_casts")
for (transpose, transpose_before_cast, second_matmul) in list(itertools.product([False, True], repeat=3)):
if not transpose and transpose_before_cast:
continue
gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul)

View file

@ -58,7 +58,7 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream,
}
graph.SetInputs(input_args);
graph_transformer_config_ = config.graph_transformer_config;
logging::LoggingManager::SetDefaultLoggerSeverity(config_.loglevel);
return Status::OK();
}
@ -154,7 +154,7 @@ Status OrtModuleGraphBuilder::OptimizeInferenceGraph(std::unordered_set<std::str
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
auto add_transformers = [&](TransformerLevel level) {
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
level, x_node_arg_names, graph_transformer_config_, *cpu_execution_provider);
level, x_node_arg_names, config_.graph_transformer_config, *cpu_execution_provider);
for (auto& entry : transformers_to_register) {
graph_transformation_mgr.Register(std::move(entry), level);
}

View file

@ -28,7 +28,11 @@ struct OrtModuleGraphBuilderConfiguration {
bool use_invertible_layernorm_grad = false;
bool build_gradient_graph = true;
// Graph transformer configuration
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
// Log severity
logging::Severity loglevel{logging::Severity::kWARNING};
};
/**
@ -113,7 +117,6 @@ class OrtModuleGraphBuilder {
OrtModuleGraphBuilderConfiguration config_;
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config_;
};
} // namespace training

View file

@ -67,7 +67,7 @@ struct TrainingParameters {
bool enable_adasum = false;
// transformation
int propagate_cast_ops_level = 1;
int propagate_cast_ops_level = -1;
std::vector<std::string> propagate_cast_ops_allow;
// graph dumping
@ -520,6 +520,14 @@ py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class use
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(
m, "OrtModuleGraphBuilderConfiguration",
R"pbdoc(Configuration information for module graph builder.)pbdoc");
py::enum_<Severity>(m, "Severity", py::arithmetic(), py::module_local())
.value("VERBOSE", logging::Severity::kVERBOSE)
.value("INFO", logging::Severity::kINFO)
.value("WARNING", logging::Severity::kWARNING)
.value("ERROR", logging::Severity::kERROR)
.value("FATAL", logging::Severity::kFATAL);
module_graph_builder_config.def(py::init())
.def_readwrite("initializer_names", &OrtModuleGraphBuilderConfiguration::initializer_names)
.def_readwrite("initializer_names_to_train", &OrtModuleGraphBuilderConfiguration::initializer_names_to_train)
@ -527,7 +535,8 @@ py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class use
.def_readwrite("use_invertible_layernorm_grad",
&OrtModuleGraphBuilderConfiguration::use_invertible_layernorm_grad)
.def_readwrite("build_gradient_graph", &OrtModuleGraphBuilderConfiguration::build_gradient_graph)
.def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config);
.def_readwrite("graph_transformer_config", &OrtModuleGraphBuilderConfiguration::graph_transformer_config)
.def_readwrite("loglevel", &OrtModuleGraphBuilderConfiguration::loglevel);
py::class_<GraphInfo> graph_info(m, "GraphInfo",
R"pbdoc(The information of split graphs for frontend.)pbdoc");

View file

@ -273,5 +273,10 @@ class GraphExecutionManager(ABC):
grad_builder_config.graph_transformer_config = C.GraphTransformerConfiguration()
grad_builder_config.graph_transformer_config.propagate_cast_ops_level = self._propagate_cast_ops_level
grad_builder_config.graph_transformer_config.propagate_cast_ops_allow = self._propagate_cast_ops_allow
grad_builder_config.loglevel = {_logger.LogLevel.VERBOSE : C.Severity.VERBOSE,
_logger.LogLevel.INFO : C.Severity.INFO,
_logger.LogLevel.WARNING : C.Severity.WARNING,
_logger.LogLevel.ERROR : C.Severity.ERROR,
_logger.LogLevel.FATAL : C.Severity.FATAL}.get(self._loglevel, C.Severity.WARNING)
self._graph_builder = C.OrtModuleGraphBuilder()
self._graph_builder.initialize(self._onnx_model.SerializeToString(), grad_builder_config)