Transformer layer-wise Recompute (#4526)

* Build Recomputation Graph

* Make topological sort to run FW nodes first

* Pattern match start and end of transformer layer

* Topological sort with Priority

* Add logger to Gradient Graph Builder

* Use Logger

* Introduce Execution Order
This commit is contained in:
Sherlock 2020-09-24 19:56:32 -07:00 committed by GitHub
parent b6e71200eb
commit b03fb82ab7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 1267 additions and 634 deletions

View file

@ -95,12 +95,20 @@ class Node {
/** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */
const std::string& Domain() const noexcept { return domain_; }
/** Gets the Node's exection priority.
@remarks Lower value means higher priority */
int Priority() const noexcept { return priority_; };
/** Sets the execution priority of a node.
@remarks Lower value means higher priority */
void SetPriority(int priority) noexcept;
/** Gets the node description. */
const std::string& Description() const noexcept { return description_; }
/** Gets the Node's Node::Type. */
Node::Type NodeType() const noexcept { return node_type_; }
/** Gets the opset version that the Node's operator was first defined in.
@returns Opset version. If -1 the Node's operator has not been set.
@remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
@ -507,6 +515,9 @@ class Node {
const ONNX_NAMESPACE::OpSchema* op_ = nullptr;
#endif
// Execution priority, lower value for higher priority
int priority_ = 0;
// set from op_->SinceVersion() or via deserialization when OpSchema is not available
int since_version_ = -1;
@ -850,6 +861,13 @@ class Graph {
const std::function<bool(const Node*, const Node*)>& comp,
const std::function<bool(const Node*, const Node*)>& stop) const;
/** Performs topological sort with Kahn's algorithm on the graph/s.
@param enter Visit function that will be invoked on a node when it is visited.
@param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
*/
void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const;
/** Gets the map of operator domains to their opset versions. */
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
return domain_to_version_;

View file

@ -4,6 +4,7 @@
#pragma once
#include "core/graph/graph.h"
#include "core/framework/session_options.h"
namespace onnxruntime {
class Function;
@ -87,7 +88,7 @@ class GraphViewer {
int MaxNodeIndex() const noexcept;
/** Gets the NodeIndex values for the Graph nodes, sorted into topological order. */
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const;
const std::vector<NodeIndex>& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const;
/**
Gets the NodeIndex values for the root nodes in the Graph.
@ -144,6 +145,10 @@ class GraphViewer {
// The NodeIndex values of the graph nodes sorted in topological order.
std::vector<NodeIndex> nodes_in_topological_order_;
// The NodeIndex values of the graph nodes sorted in topological order with priority.
std::vector<NodeIndex> nodes_in_topological_order_with_priority_;
// Graph root nodes.
std::vector<NodeIndex> root_nodes_;
};

View file

@ -768,7 +768,7 @@ class PlannerImpl {
}; // namespace onnxruntime
Status PlannerImpl::CreatePlan() {
auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder();
auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(context_.GetExecutionOrder());
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;

View file

@ -28,22 +28,28 @@ class ISequentialPlannerContext {
// If it returns true, planner won't reuse output tensors
// see PlannerImpl::ComputeReusePlan
virtual bool IsParallelExecutionEnabled() const { return false; }
virtual ExecutionOrder GetExecutionOrder() const { return ExecutionOrder::DEFAULT; }
};
class SequentialPlannerContext : public ISequentialPlannerContext {
public:
SequentialPlannerContext(ExecutionMode execution_mode)
: m_execution_mode(execution_mode) {
SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order)
: execution_mode_(execution_mode),
exection_order_(execution_order) {
}
const ONNX_NAMESPACE::TensorShapeProto* GetShape(const onnxruntime::NodeArg& arg) const override {
return arg.Shape();
}
bool IsParallelExecutionEnabled() const override { return m_execution_mode == ExecutionMode::ORT_PARALLEL; }
bool IsParallelExecutionEnabled() const override { return execution_mode_ == ExecutionMode::ORT_PARALLEL; }
ExecutionOrder GetExecutionOrder() const override { return exection_order_; }
private:
ExecutionMode m_execution_mode = ExecutionMode::ORT_SEQUENTIAL;
ExecutionMode execution_mode_ = ExecutionMode::ORT_SEQUENTIAL;
ExecutionOrder exection_order_ = ExecutionOrder::DEFAULT;
};
class SequentialPlanner {

View file

@ -11,12 +11,25 @@
namespace onnxruntime {
enum class ExecutionOrder {
DEFAULT = 0, // default topological sort
PRIORITY_BASED = 1 // priority-based topological sort
};
enum class FreeDimensionOverrideType {
Invalid = 0,
Denotation = 1,
Name = 2
};
enum class ExecutionPriority : int {
GLOBAL_HIGHT = -100,
LOCAL_HIGH = -10,
DEFAULT = 0,
LOCAL_LOW = 10,
GLOBAL_LOW = 100
};
struct FreeDimensionOverride {
std::string dim_identifier;
FreeDimensionOverrideType dim_identifer_type;
@ -29,6 +42,9 @@ struct FreeDimensionOverride {
struct SessionOptions {
ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL;
// set the execution order of the graph
ExecutionOrder execution_order = ExecutionOrder::DEFAULT;
// enable profiling for this session.
bool enable_profiling = false;

View file

@ -832,7 +832,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
});
}
SequentialPlannerContext context(session_options.execution_mode);
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order);
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
execution_providers_, kernel_create_info_map_,
ort_value_name_idx_map_, context, p_seq_exec_plan_));

View file

@ -11,6 +11,7 @@
#include <iostream>
#include <numeric>
#include <stack>
#include <queue>
#include "gsl/gsl"
#include "core/common/logging/logging.h"
@ -420,6 +421,10 @@ const Node* Node::NodeConstIterator::operator->() const {
return &(operator*());
}
void Node::SetPriority(int priority) noexcept {
priority_ = priority;
}
#if !defined(ORT_MINIMAL_BUILD)
void Node::SetNodeType(Node::Type node_type) noexcept {
@ -677,6 +682,7 @@ void Node::Init(const std::string& name,
definitions_.input_defs = input_args;
definitions_.output_defs = output_args;
domain_ = domain;
priority_ = 0;
if (kOnnxDomainAlias == domain_) {
domain_ = kOnnxDomain;
}
@ -1560,6 +1566,44 @@ void Graph::ReverseDFSFrom(const std::vector<const Node*>& from,
}
}
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const {
std::unordered_map<NodeIndex, size_t> in_degree;
std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
std::vector<NodeIndex> topo_order;
for (auto& node : Nodes()) {
size_t input_edge_count = node.GetInputEdgesCount();
in_degree.insert({node.Index(), input_edge_count});
if (input_edge_count == 0) {
to_visit.push(&node);
}
}
while (!to_visit.empty()) {
const Node* current = to_visit.top();
to_visit.pop();
if (!current) continue;
if (enter) {
enter(current);
}
for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) {
in_degree[node_it->Index()]--;
if (in_degree[node_it->Index()] == 0) {
to_visit.push(&*node_it);
}
}
topo_order.push_back(current->Index());
}
if (NumberOfNodes() != static_cast<int>(topo_order.size())) {
ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
}
}
#if !defined(ORT_MINIMAL_BUILD)
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)

View file

@ -14,15 +14,45 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
return n1->Index() < n2->Index();
}
struct PriorityNodeCompare {
inline bool IsHighPri(const Node* n) const {
static const std::unordered_set<std::string> high_pri_ops = {"Shape", "Size"};
return high_pri_ops.find(n->OpType()) != high_pri_ops.end();
}
// Used for std::priority_queue
// If return false, n1 will be output first
// If return true, n2 will be output first
bool operator()(const Node* n1, const Node* n2) const {
// nodes in global high priorty list will be output first
if (IsHighPri(n1) != IsHighPri(n2)) {
return IsHighPri(n2);
}
// nodes with lower priority value will be output first
if (n1->Priority() != n2->Priority()) {
return n1->Priority() > n2->Priority();
}
// otherwise, nodes with lower index will be output first
return n1->Index() > n2->Index();
}
};
GraphViewer::GraphViewer(const Graph& graph) {
graph_ = &graph;
std::vector<const Node*> leaf_nodes;
for (auto& node : graph_->Nodes()) {
// This is a leaf node (without any output node)
if (node.OutputNodesBegin() == node.OutputNodesEnd()) {
// This is a leaf node (without any output node).
leaf_nodes.push_back(&node);
}
// This is a root node (without any input node)
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
root_nodes_.push_back(node.Index());
}
}
graph.ReverseDFSFrom(
leaf_nodes,
nullptr,
@ -31,11 +61,11 @@ GraphViewer::GraphViewer(const Graph& graph) {
},
NodeCompare());
for (auto& node : graph_->Nodes()) {
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
root_nodes_.push_back(node.Index());
}
}
graph.KahnsTopologicalSort(
[this](const Node* n) {
nodes_in_topological_order_with_priority_.push_back(n->Index());
},
PriorityNodeCompare());
}
// Graph name.
@ -92,8 +122,15 @@ int GraphViewer::MaxNodeIndex() const noexcept {
return graph_->MaxNodeIndex();
}
const std::vector<NodeIndex>& GraphViewer::GetNodesInTopologicalOrder() const {
return nodes_in_topological_order_;
const std::vector<NodeIndex>& GraphViewer::GetNodesInTopologicalOrder(ExecutionOrder order) const {
switch (order) {
case ExecutionOrder::DEFAULT:
return nodes_in_topological_order_;
case ExecutionOrder::PRIORITY_BASED:
return nodes_in_topological_order_with_priority_;
default:
ORT_THROW("Invalide ExecutionOrder");
}
}
const std::vector<NodeIndex>& GraphViewer::GetRootNodes() const {

File diff suppressed because it is too large Load diff

View file

@ -1089,6 +1089,8 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
.def_readwrite("execution_mode", &PySessionOptions::execution_mode,
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
.def_readwrite("execution_order", &SessionOptions::execution_order,
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
.def_property(
"graph_optimization_level",
[](const PySessionOptions* options) -> GraphOptimizationLevel {

View file

@ -231,12 +231,12 @@ TEST_F(GraphTest, SimpleUnique) {
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(std::move(m), model, nullptr, *logger_));
}
TEST_F(GraphTest, UnusedValueInfoSerializes) {
ModelProto m;
m.set_ir_version(4);
ImportOpset(m, "", 11);
GraphProto& g = *m.mutable_graph();
GraphProto& g = *m.mutable_graph();
NodeProto* node = g.add_node();
*node->add_input() = "x";
*node->add_output() = "sum";
@ -633,9 +633,6 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
// node_5 (Merge)
// |
std::unordered_map<std::string, std::pair<std::vector<NodeArg*>, std::vector<NodeArg*>>>
expected_node_name_to_input_output_args;
TypeProto tensor_int32;
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
@ -655,29 +652,24 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
inputs.push_back(&input_arg1);
outputs.push_back(&output_arg1);
expected_node_name_to_input_output_args["node_1"] = {inputs, outputs};
graph.AddNode("node_1", "Identity_Fake", "node 1", inputs, outputs);
inputs[0] = &input_arg2;
outputs[0] = &output_arg2;
expected_node_name_to_input_output_args["node_2"] = {inputs, outputs};
graph.AddNode("node_2", "Identity_Fake", "node 2", inputs, outputs);
inputs[0] = &output_arg2;
outputs[0] = &output_arg3;
expected_node_name_to_input_output_args["node_3"] = {inputs, outputs};
graph.AddNode("node_3", "Identity_Fake", "node 3", inputs, outputs);
inputs[0] = &output_arg1;
outputs[0] = &output_arg4;
expected_node_name_to_input_output_args["node_4"] = {inputs, outputs};
graph.AddNode("node_4", "Identity_Fake", "node 4", inputs, outputs);
inputs.resize(2);
inputs[0] = &output_arg4;
inputs[1] = &output_arg3;
outputs[0] = &output_arg5;
expected_node_name_to_input_output_args["node_5"] = {inputs, outputs};
graph.AddNode("node_5", "Merge_Fake", "node 3", inputs, outputs);
auto status = graph.Resolve();
@ -700,6 +692,223 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
}
}
TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompress) {
Model model("graph_1", false, *logger_);
auto& graph = model.MainGraph();
/*
|
node_0 (Identity)
/ \
node_1 (Identity) compress (pri = LOCAL_HIGH)
| |
node_4 (Identity) decompress (pri = LOCAL_LOW)
\ /
node_5 (Merge)
|
*/
TypeProto tensor_int32;
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32);
auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32);
auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32);
auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out_1", &tensor_int32);
auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32);
auto& output_arg5 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32);
graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});
auto& compress_node = graph.AddNode("compress", "Identity_Fake", "compress node", {&output_arg0}, {&output_arg2});
compress_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_HIGH));
auto& decompress_node = graph.AddNode("decompress", "Identity_Fake", "decompress node", {&output_arg2}, {&output_arg3});
decompress_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4});
graph.AddNode("node_5", "Merge_Fake", "node 3", {&output_arg4, &output_arg3}, {&output_arg5});
auto status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
GraphViewer graph_viewer(graph);
// PRIORITY_BASED order
{
auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
const std::vector<std::string> expected_priority_based_order =
{"node_0", "compress", "node_1", "node_4", "decompress", "node_5"};
for (size_t i = 0; i < order.size(); ++i) {
auto node = graph.GetNode(order[i]);
EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "Priority based execution order is wrong.";
}
}
// TOPOLOGICAL order
{
auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT);
const std::vector<std::string> expected_topological_order = {
"node_0", "node_1", "node_4", "compress", "decompress", "node_5"};
for (size_t i = 0; i < order.size(); ++i) {
auto node = graph.GetNode(order[i]);
EXPECT_TRUE(node->Name() == expected_topological_order[i]) << "Priority based execution order is wrong.";
}
}
}
TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_Recompute) {
Model model("graph_1", false, *logger_);
auto& graph = model.MainGraph();
/*
|
node_0 (Identity)
/ \
node_1 (Identity) recompute_node_1 (pri = LOCAL_LOW)
| |
node_4 (Identity) |
\ /
node_1_grad (Merge)
|
*/
TypeProto tensor_int32;
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32);
auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32);
auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32);
auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32);
auto& output_arg5 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32);
graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});
auto& recompute_node = graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2});
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4});
graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg4, &output_arg2}, {&output_arg5});
auto status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
GraphViewer graph_viewer(graph);
// PRIORITY_BASED order
{
auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
const std::vector<std::string> expected_priority_based_order =
{"node_0", "node_1", "node_4", "recompute_node_1", "node_1_grad"};
for (size_t i = 0; i < order.size(); ++i) {
auto node = graph.GetNode(order[i]);
EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "Priority based execution order is wrong.";
}
}
}
TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_MultiLayerRecompute) {
Model model("graph_1", false, *logger_);
auto& graph = model.MainGraph();
/*
|
node_0 (Identity)
/ \
node_1 (Identity) \
| \ \
node_2 (Identity) \ \
| \ \ \
node_3 (Identity) \ \ \
| \ \ \ \
loss (Identity) \ \ \ \
| | \ \ \
1 | | \ \
\ / | \ |
loss_grad recom_node_3 | |
\ / | |
node_3_grad recom_node_2 |
\ / |
node_2_grad recom_node_1
\ /
node_1_grad
|
*/
TypeProto tensor_int32;
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
// FW graph
auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in", &tensor_int32);
auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out", &tensor_int32);
auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out", &tensor_int32);
auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out", &tensor_int32);
auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out", &tensor_int32);
auto& output_loss = graph.GetOrCreateNodeArg("loss_out", &tensor_int32);
graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});
graph.AddNode("node_2", "Identity_Fake", "node 2", {&output_arg1}, {&output_arg2});
graph.AddNode("node_3", "Identity_Fake", "node 3", {&output_arg2}, {&output_arg3});
graph.AddNode("loss", "Identity_Fake", "loss node", {&output_arg3}, {&output_loss});
// Recompute graph
auto& recomputed_arg3 = graph.GetOrCreateNodeArg("node_3_out_recomputed", &tensor_int32);
auto& recomputed_arg2 = graph.GetOrCreateNodeArg("node_2_out_recomputed", &tensor_int32);
auto& recomputed_arg1 = graph.GetOrCreateNodeArg("node_1_out_recomputed", &tensor_int32);
auto& recompute_node3 = graph.AddNode("node_3_recompute", "Identity_Fake", "node 3 recompute", {&output_arg2}, {&recomputed_arg3});
auto& recompute_node2 = graph.AddNode("node_2_recompute", "Identity_Fake", "node 2 recompute", {&output_arg1}, {&recomputed_arg2});
auto& recompute_node1 = graph.AddNode("node_1_recompute", "Identity_Fake", "node 1 recompute", {&output_arg0}, {&recomputed_arg1});
recompute_node3.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
recompute_node2.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
recompute_node1.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
// BW Graph
auto& gradient_start = graph.GetOrCreateNodeArg("gradient_start", &tensor_int32);
auto& loss_grad_output = graph.GetOrCreateNodeArg("loss_grad_output", &tensor_int32);
auto& node_3_grad_output = graph.GetOrCreateNodeArg("node_3_grad_output", &tensor_int32);
auto& node_2_grad_output = graph.GetOrCreateNodeArg("node_2_grad_output", &tensor_int32);
auto& node_1_grad_output = graph.GetOrCreateNodeArg("node_1_grad_output", &tensor_int32);
graph.AddNode("loss_grad", "Merge_Fake", "loss gradient", {&gradient_start, &output_arg3}, {&loss_grad_output});
graph.AddNode("node_3_grad", "Merge_Fake", "node 3 gradient", {&loss_grad_output, &recomputed_arg3}, {&node_3_grad_output});
graph.AddNode("node_2_grad", "Merge_Fake", "node 2 gradient", {&node_3_grad_output, &recomputed_arg2}, {&node_2_grad_output});
graph.AddNode("node_1_grad", "Merge_Fake", "node 1 gradient", {&node_2_grad_output, &recomputed_arg1}, {&node_1_grad_output});
auto status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
GraphViewer graph_viewer(graph);
// PRIORITY_BASED order
{
auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
const std::vector<std::string> expected_priority_based_order = {
"node_0",
"node_1",
"node_2",
"node_3",
"loss",
"loss_grad",
"node_3_recompute",
"node_3_grad",
"node_2_recompute",
"node_2_grad",
"node_1_recompute",
"node_1_grad",
};
for (size_t i = 0; i < order.size(); ++i) {
auto node = graph.GetNode(order[i]);
EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "Priority based execution order is wrong.";
}
}
}
TEST_F(GraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) {
Model model("graph_1", false, *logger_);
auto& graph = model.MainGraph();

View file

@ -12,7 +12,6 @@
#include "core/optimizer/rule_based_graph_transformer.h"
using namespace ONNX_NAMESPACE;
using namespace std;
namespace onnxruntime {
namespace training {
@ -20,8 +19,8 @@ namespace training {
using namespace common;
GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
const unordered_set<string>& y_node_arg_names,
const unordered_set<string>& x_node_arg_names,
const std::unordered_set<std::string>& y_node_arg_names,
const std::unordered_set<std::string>& x_node_arg_names,
const std::string& loss_node_arg_name,
const GradientGraphConfiguration& gradient_graph_config,
const logging::Logger& logger)
@ -61,6 +60,11 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
y_nodes_.insert(node);
}
reachable_nodes_ = ReverseBFS(y_nodes_);
std::string unreachable_nodes;
// building x_nodes_
for (const auto& name : x_node_arg_names) {
const NodeArg* node_arg = graph->GetNodeArg(name);
if (!node_arg) {
@ -68,21 +72,29 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
}
x_node_args_.insert(node_arg);
vector<const Node*> nodes = graph_->GetConsumerNodes(name);
std::vector<const Node*> nodes = graph_->GetConsumerNodes(name);
if (nodes.empty()) {
ORT_THROW(name, " couldn't find the consumer node.");
}
string grad_arg_name = GradientBuilderBase::GradientName(name);
pending_[grad_arg_name] = static_cast<int>(nodes.size());
std::string grad_arg_name = GradientBuilderBase::GradientName(name);
pending_[grad_arg_name] = 0;
x_nodes_.insert(nodes.begin(), nodes.end());
for (const Node* node : nodes) {
if (IsReachable(node)) {
pending_[grad_arg_name] += 1;
x_nodes_.insert(node);
} else {
unreachable_nodes.append(node->Name() + ", ");
}
}
}
LOGS(logger_, WARNING) << "Following nodes are unreachable for gradient back propagation: " << unreachable_nodes;
}
NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) const {
NodeSet visited(nodes);
deque<const Node*> queue(nodes.begin(), nodes.end());
std::deque<const Node*> queue(nodes.begin(), nodes.end());
while (!queue.empty()) {
const Node* n = queue.front();
@ -106,13 +118,13 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
return visited;
}
Status GradientGraphBuilder::CheckNodeArgsReachable(const NodeSet& reachable_nodes) {
Status GradientGraphBuilder::CheckNodeArgsReachable() const {
for (const NodeArg* node_arg : x_node_args_) {
auto nodes = graph_->GetConsumerNodes(node_arg->Name());
bool reachable = false;
for (const Node* node : nodes) {
if (reachable_nodes.find(node) != reachable_nodes.end()) {
if (IsReachable(node)) {
reachable = true;
break;
}
@ -141,14 +153,13 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
gradient_graph_defs.AddInitializers({tensor_proto});
}
NodeSet reachable_nodes = ReverseBFS(y_nodes_);
ORT_RETURN_IF_ERROR(CheckNodeArgsReachable(reachable_nodes));
ORT_RETURN_IF_ERROR(CheckNodeArgsReachable());
// Going forward to figure out which node_args need backprop-ed.
deque<const Node*> queue(x_nodes_.begin(), x_nodes_.end());
std::deque<const Node*> queue(x_nodes_.begin(), x_nodes_.end());
NodeSet visited(x_nodes_);
unordered_set<const NodeArg*> visited_node_args = x_node_args_;
std::unordered_set<const NodeArg*> visited_node_args = x_node_args_;
visited_node_args.insert(y_node_args_.begin(), y_node_args_.end());
while (!queue.empty()) {
@ -158,7 +169,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
for (auto edge_it = node->OutputEdgesBegin(); edge_it != node->OutputEdgesEnd(); ++edge_it) {
const Node& next_node = edge_it->GetNode();
if (reachable_nodes.find(&next_node) == reachable_nodes.end()) continue;
if (!IsReachable(&next_node)) continue;
auto it = STOP_GRADIENT_EDGES.find(next_node.OpType());
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
@ -168,7 +179,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
}
const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()];
string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name());
std::string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name());
pending_[grad_node_arg_name] += 1;
@ -185,7 +196,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
// visited_node_args are the node_args involved
for (auto node : visited) {
//TODO: might not need two sets, the union of them might be enough
unordered_set<string> input_args_need_grad, output_args_need_grad;
std::unordered_set<std::string> input_args_need_grad, output_args_need_grad;
for (auto arg : node->InputDefs()) {
if (visited_node_args.find(arg) != visited_node_args.end()) {
input_args_need_grad.insert(arg->Name());
@ -205,7 +216,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
auto found = pending_.find(arg.name);
if (found != pending_.end() && found->second > 1) {
auto idx = gradients_to_accumulate_[arg].size();
string indexed_arg_name = arg.name + "_" + to_string(idx);
std::string indexed_arg_name = arg.name + "_" + to_string(idx);
gradients_to_accumulate_[arg].push_back(ArgDef(indexed_arg_name, arg.type_proto));
arg.name = indexed_arg_name;

View file

@ -89,6 +89,7 @@ class GradientGraphBuilder {
NodeSet y_nodes_;
NodeSet x_nodes_;
NodeSet reachable_nodes_;
Graph* graph_;
@ -117,14 +118,22 @@ class GradientGraphBuilder {
@param nodes Starting nodes for ReverseBFS
@returns All the nodes visited during ReverseBFS
*/
NodeSet ReverseBFS(const NodeSet& nodes);
NodeSet ReverseBFS(const NodeSet& nodes) const;
/**
Check if 'x_node_args_' are reachable from 'y_node_args_' for computing the partial derivative
@param reachable_nodes All the nodes reachable from the 'y_node_args_'
@returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status
*/
Status CheckNodeArgsReachable(const NodeSet& reachable_nodes);
Status CheckNodeArgsReachable() const;
/**
Check if node is reachable from the 'y_node_args_'
**/
bool IsReachable(const Node* node) const {
return reachable_nodes_.find(node) != reachable_nodes_.end();
}
};
} // namespace training

View file

@ -663,9 +663,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) {
}
}
return std::vector<NodeDef>{
NodeDef("ReshapeGrad",
{I(0), GO(0)},
{GI(0)})};
NodeDef("Shape", {I(0)}, {IA("x_shape")}),
NodeDef("Reshape", {GO(0), IA("x_shape")}, {GI(0)})};
}
IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) {

View file

@ -95,6 +95,15 @@ class GradientBuilderBase {
// i-th output of forward op
ArgDef O(const size_t i) const {
ORT_ENFORCE(i < node_->OutputDefs().size());
const std::string& name = node_->OutputDefs()[i]->Name();
const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name));
if (recomputed_nodearg) {
const Node* producer_node = graph_->GetProducerNode(name);
LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name();
return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto());
}
return ArgDef(node_->OutputDefs()[i]->Name(), node_->OutputDefs()[i]->TypeAsProto());
}

View file

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/optimizer/dropout_recompute.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
namespace onnxruntime {
Node& InsertDropoutRecompute(Graph& graph, Node& node, bool use_original_input) {
NodeArg* input = node.MutableInputDefs()[0];
if (!use_original_input) {
auto& recomputed_input = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(input->Name()),
input->TypeAsProto());
input = &recomputed_input;
}
const auto& output = node.OutputDefs()[0];
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
"DropoutGrad",
"Recompute of " + node.Name(),
{
input, // X
node.MutableOutputDefs()[1], // mask
node.MutableInputDefs()[1], // ratio
node.MutableInputDefs()[2] // training_mode
},
{&recomputed_output},
{},
kMSDomain);
return recompute_node;
}
} // namespace onnxruntime

View file

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/graph.h"
namespace onnxruntime {
Node& InsertDropoutRecompute(Graph& graph, Node& node, bool use_original_input);
} // namespace onnxruntime

View file

@ -43,6 +43,7 @@
#include "orttraining/core/optimizer/localized_recompute.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
#include "orttraining/core/optimizer/transformer_layer_recompute.h"
namespace onnxruntime {
namespace training {
@ -73,10 +74,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
if (config.gelu_checkpoint) {
if (config.gelu_recompute) {
rule_transformer->Register(make_unique<GeluRecompute>());
}
if (config.attn_dropout_checkpoint) {
if (config.attn_dropout_recompute) {
rule_transformer->Register(make_unique<AttentionDropoutRecompute>());
}
@ -104,6 +105,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
horizontal_parallel_size, compatible_eps));
}
transformers.emplace_back(onnxruntime::make_unique<ComputationReductionTransformer>(compatible_eps));
if (config.transformer_layer_recompute) {
transformers.emplace_back(onnxruntime::make_unique<TransformerLayerRecompute>(compatible_eps));
}
} break;
case TransformerLevel::Level2: {

View file

@ -4,6 +4,7 @@
#include "core/graph/graph_utils.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "orttraining/core/optimizer/localized_recompute.h"
#include "orttraining/core/optimizer/dropout_recompute.h"
using namespace ONNX_NAMESPACE;
@ -23,13 +24,15 @@ Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
graph.AddNode(node.Name() + "_recompute",
node.OpType(),
"Recompute of " + node.Name(),
{node.MutableInputDefs()[0]},
{&recomputed_output},
&node.GetAttributes(),
node.Domain());
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
node.OpType(),
"Recompute of " + node.Name(),
{node.MutableInputDefs()[0]},
{&recomputed_output},
&node.GetAttributes(),
node.Domain());
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
return Status::OK();
@ -46,23 +49,8 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const N
}
Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
const auto& output = node.OutputDefs()[0];
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
graph.AddNode(node.Name() + "_recompute",
"DropoutGrad", // Reusing DropoutGrad as the recompute op
"Recompute of " + node.Name(),
{
node.MutableInputDefs()[0], // X
node.MutableOutputDefs()[1], // mask
node.MutableInputDefs()[1], // ratio
node.MutableInputDefs()[2] // training_mode
},
{&recomputed_output},
{},
kMSDomain);
Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true);
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
return Status::OK();

View file

@ -0,0 +1,189 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/optimizer/transformer_layer_recompute.h"
#include "orttraining/core/optimizer/dropout_recompute.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "core/common/common.h"
#include <deque>
namespace onnxruntime {
Status TransformerLayerRecompute::IdentifyTransformerLayerEdges(
const Graph& graph,
std::vector<std::pair<const NodeArg*, const NodeArg*>>& start_end_edges,
const logging::Logger& logger) const {
const std::unordered_set<std::string> gelu_ops{"Gelu", "BiasGelu", "FastGelu"};
const std::unordered_set<std::string> dropout_ops{"Dropout", "BiasDropout"};
const std::unordered_set<std::string> layernorm_ops{"LayerNormalization", "SkipLayerNormalization"};
std::vector<const NodeArg*> layer_start_edges, layer_end_edges;
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
auto& node = *graph.GetNode(node_index);
// Look for start of a transformer layer
if ((layernorm_ops.find(node.OpType()) != layernorm_ops.end() ||
dropout_ops.find(node.OpType()) != dropout_ops.end()) &&
node.GetOutputEdgesCount() == 4) {
layer_start_edges.push_back(node.OutputDefs()[0]);
}
// Look for end of a transformer layer
if (gelu_ops.find(node.OpType()) != gelu_ops.end()) {
auto next_node = node.OutputNodesBegin();
while (next_node->OutputNodesBegin() != next_node->OutputNodesEnd() &&
dropout_ops.find(next_node->OpType()) == dropout_ops.end()) {
next_node = next_node->OutputNodesBegin();
}
while (next_node->OutputNodesBegin() != next_node->OutputNodesEnd() &&
layernorm_ops.find(next_node->OpType()) == layernorm_ops.end()) {
next_node = next_node->OutputNodesBegin();
}
if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
layer_end_edges.push_back(next_node->OutputDefs()[0]);
}
}
}
ORT_RETURN_IF_NOT(layer_start_edges.size() == layer_end_edges.size(),
"Number of start and end edges doesn't match!, #start=", layer_start_edges.size(),
", #end=", layer_end_edges.size());
start_end_edges.clear();
LOGS(logger, INFO) << "Found " << layer_start_edges.size() << " transformer layers.";
for (size_t i = 0; i < layer_start_edges.size(); ++i) {
start_end_edges.push_back({layer_start_edges[i], layer_end_edges[i]});
LOGS(logger, INFO) << "Start edge: " << layer_start_edges[i]->Name() << " End edge: " << layer_end_edges[i]->Name();
}
return Status::OK();
}
namespace {
typedef std::set<const Node*, NodeCompare> NodeSet;
NodeSet BFSFrom(const std::vector<const Node*>& start_nodes, bool reverse) {
NodeSet visited(start_nodes.begin(), start_nodes.end());
std::deque<const Node*> queue(start_nodes.begin(), start_nodes.end());
while (!queue.empty()) {
const Node* n = queue.front();
queue.pop_front();
auto begin = reverse ? n->InputNodesBegin() : n->OutputNodesBegin();
auto end = reverse ? n->InputNodesEnd() : n->OutputNodesEnd();
for (auto node_it = begin; node_it != end; ++node_it) {
const Node& node = *node_it;
if (visited.find(&node) == visited.end()) {
queue.push_back(&node);
visited.insert(&node);
}
}
}
return visited;
}
} // namespace
std::vector<const Node*> TransformerLayerRecompute::NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const {
// Forward BFS from the start node
std::vector<const Node*> start_nodes = graph.GetConsumerNodes(start->Name());
NodeSet fw_visited = BFSFrom(start_nodes, /*reverse*/ false);
// Reverse BFS from the end node
const Node* end_node = graph.GetProducerNode(end->Name());
NodeSet bw_visited = BFSFrom({end_node}, /*reverse*/ true);
// Join fw_visited and bw_visited
std::vector<const Node*> intersect_nodes;
std::set_intersection(fw_visited.begin(), fw_visited.end(),
bw_visited.begin(), bw_visited.end(),
std::back_inserter(intersect_nodes), NodeCompare());
return intersect_nodes;
}
void TransformerLayerRecompute::InsertRecomputeNodes(Graph& graph, const std::vector<const Node*>& nodes, int priority) const {
auto initializers = graph.GetAllInitializedTensors();
for (const Node* n : nodes) {
Node* node = graph.GetNode(n->Index());
// recomputed Dropout need to produce the same output as original dropout
// currently reusing original dropout's mask to achieve this
if (node->OpType() == "Dropout") {
const NodeArg* input = node->InputDefs()[0];
const Node* p_node = graph.GetProducerNode(input->Name());
bool use_original_input =
initializers.find(input->Name()) != initializers.end() ||
std::find(nodes.begin(), nodes.end(), p_node) == nodes.end();
Node& recompute_node = InsertDropoutRecompute(graph, *node, use_original_input);
recompute_node.SetPriority(priority);
continue;
}
// prepare inputs for recompute node
std::vector<NodeArg*> recomputed_inputs;
for (NodeArg* input : node->MutableInputDefs()) {
const Node* p_node = graph.GetProducerNode(input->Name());
// do not duplicate initializers in recompute subgraph
if (initializers.find(input->Name()) != initializers.end() ||
std::find(nodes.begin(), nodes.end(), p_node) == nodes.end()) {
recomputed_inputs.push_back(input);
} else {
auto& recomputed_input = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(input->Name()),
input->TypeAsProto());
recomputed_inputs.push_back(&recomputed_input);
}
}
// prepare ouputs for recompute node
std::vector<NodeArg*> recomputed_outputs;
for (NodeArg* output : node->MutableOutputDefs()) {
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
output->TypeAsProto());
recomputed_outputs.push_back(&recomputed_output);
}
Node& recompute_node = graph.AddNode(node->Name() + "_recompute",
node->OpType(),
"Recompute of " + node->Name(),
recomputed_inputs,
recomputed_outputs,
&node->GetAttributes(),
node->Domain());
recompute_node.SetPriority(priority);
}
return;
}
Status TransformerLayerRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const {
std::vector<std::pair<const NodeArg*, const NodeArg*>> start_end_edges;
Status s = IdentifyTransformerLayerEdges(graph, start_end_edges, logger);
if (!s.IsOK()) {
modified = false;
return Status::OK();
}
// insert recompute nodes expect for the last transformer layer
// latter recompute layers have higher execution priorty
for (size_t i = 0; i < start_end_edges.size() - 1; ++i) {
std::vector<const Node*> nodes = NodesBetweenEdges(graph, start_end_edges[i].first, start_end_edges[i].second);
InsertRecomputeNodes(graph, nodes, static_cast<int>(start_end_edges.size() - i));
}
modified = true;
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
#include "core/graph/graph_utils.h"
namespace onnxruntime {
class TransformerLayerRecompute : public GraphTransformer {
public:
TransformerLayerRecompute(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("TransformerLayerRecompute", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
private:
Status IdentifyTransformerLayerEdges(const Graph& graph,
std::vector<std::pair<const NodeArg*, const NodeArg*>>& start_end_edges,
const logging::Logger& logger) const;
std::vector<const Node*> NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const;
void InsertRecomputeNodes(Graph& graph, const std::vector<const Node*>& nodes, int priority) const;
};
} // namespace onnxruntime

View file

@ -217,11 +217,6 @@ Status TrainingSession::ConfigureForTraining(
config_result.mixed_precision_config_result = mp_result;
}
if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_loss_function_path.value(), SaveOption::NO_RELOAD));
}
// We need to get trainable weights to prevent constant folding from them. This works well if trainable weights are passed from config.
// For case we use GetTrainableModelInitializers to get trainable weights such as C++ frontend, it may get more initializers
// than trainable weights here as it's before transformers. So the constant folding may miss some nodes we actually can fold.
@ -239,6 +234,11 @@ Status TrainingSession::ConfigureForTraining(
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));
if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_loss_function_path.value(), SaveOption::NO_RELOAD));
}
// derive actual set of weights to train
std::unordered_set<std::string> weight_names_to_train =
!filtered_config_weight_names_to_train.empty()

View file

@ -194,10 +194,12 @@ class TrainingSession : public InferenceSession {
struct GraphTransformerConfiguration {
// Whether to enable GELU approximation which is faster but produces different results.
bool enable_gelu_approximation{false};
// Enable checkpointing of attention dropout to save memory
bool attn_dropout_checkpoint{false};
// Enable checkpointing of Gelu activation output to save memory
bool gelu_checkpoint{false};
// Enable recompute of attention dropout to save memory
bool attn_dropout_recompute{false};
// Enable recompute of Gelu activation output to save memory
bool gelu_recompute{false};
// Enable recompute of transformer layer ouput to save memory
bool transformer_layer_recompute{false};
};
GraphTransformerConfiguration graph_transformer_config{};

View file

@ -167,9 +167,11 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
cxxopts::value<bool>()->default_value("true"))
("enable_gelu_approximation", "Specify whether to enable GELU approximation.",
cxxopts::value<bool>()->default_value("true"))
("attn_dropout_checkpoint", "Enable checkpointing of attention dropout to save memory.",
("attn_dropout_recompute", "Enable checkpointing of attention dropout to save memory.",
cxxopts::value<bool>()->default_value("false"))
("gelu_checkpoint", "Enable checkpointing of Gelu activation output to save memory.",
("gelu_recompute", "Enable checkpointing of Gelu activation output to save memory.",
cxxopts::value<bool>()->default_value("false"))
("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.",
cxxopts::value<bool>()->default_value("false"))
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
cxxopts::value<bool>()->default_value("false"));
@ -458,8 +460,9 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
}
params.enable_gelu_approximation = flags["enable_gelu_approximation"].as<bool>();
params.attn_dropout_checkpoint = flags["attn_dropout_checkpoint"].as<bool>();
params.gelu_checkpoint = flags["gelu_checkpoint"].as<bool>();
params.attn_dropout_recompute = flags["attn_dropout_recompute"].as<bool>();
params.gelu_recompute = flags["gelu_recompute"].as<bool>();
params.transformer_layer_recompute = flags["transformer_layer_recompute"].as<bool>();
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
ORT_RETURN_IF_NOT(

View file

@ -34,6 +34,7 @@ namespace training {
static std::vector<FreeDimensionOverride> overrides = {};
static SessionOptions SESSION_OPTION = {
ExecutionMode::ORT_SEQUENTIAL, //execution_mode
ExecutionOrder::PRIORITY_BASED, //execution_order
false, //enable_profiling
ORT_TSTR(""), //optimized_model_filepath
true, //enable_mem_pattern
@ -183,8 +184,9 @@ Status TrainingRunner::Initialize() {
{
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration gt_config{};
gt_config.enable_gelu_approximation = params_.enable_gelu_approximation;
gt_config.attn_dropout_checkpoint = params_.attn_dropout_checkpoint;
gt_config.gelu_checkpoint = params_.gelu_checkpoint;
gt_config.attn_dropout_recompute = params_.attn_dropout_recompute;
gt_config.gelu_recompute = params_.gelu_recompute;
gt_config.transformer_layer_recompute = params_.transformer_layer_recompute;
config.graph_transformer_config = gt_config;
}

View file

@ -175,10 +175,11 @@ class TrainingRunner {
// Enable GELU approximation
bool enable_gelu_approximation = false;
// Enable checkpointing of attention dropout to save memory
bool attn_dropout_checkpoint = false;
bool attn_dropout_recompute = false;
// Enable checkpointing of Gelu activation output to save memory
bool gelu_checkpoint = false;
bool gelu_recompute = false;
// Enable checkpointing of transformer layer output to save memory
bool transformer_layer_recompute = false;
// Use invertible layernorm grad
bool use_invertible_layernorm_grad = false;
};