mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Transformer layer-wise Recompute (#4526)
* Build Recomputation Graph * Make topological sort to run FW nodes first * Pattern match start and end of transformer layer * Topological sort with Priority * Add logger to Gradient Graph Builder * Use Logger * Introduce Execution Order
This commit is contained in:
parent
b6e71200eb
commit
b03fb82ab7
26 changed files with 1267 additions and 634 deletions
|
|
@ -95,12 +95,20 @@ class Node {
|
|||
/** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */
|
||||
const std::string& Domain() const noexcept { return domain_; }
|
||||
|
||||
/** Gets the Node's exection priority.
|
||||
@remarks Lower value means higher priority */
|
||||
int Priority() const noexcept { return priority_; };
|
||||
|
||||
/** Sets the execution priority of a node.
|
||||
@remarks Lower value means higher priority */
|
||||
void SetPriority(int priority) noexcept;
|
||||
|
||||
/** Gets the node description. */
|
||||
const std::string& Description() const noexcept { return description_; }
|
||||
|
||||
/** Gets the Node's Node::Type. */
|
||||
Node::Type NodeType() const noexcept { return node_type_; }
|
||||
|
||||
|
||||
/** Gets the opset version that the Node's operator was first defined in.
|
||||
@returns Opset version. If -1 the Node's operator has not been set.
|
||||
@remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
|
||||
|
|
@ -507,6 +515,9 @@ class Node {
|
|||
const ONNX_NAMESPACE::OpSchema* op_ = nullptr;
|
||||
#endif
|
||||
|
||||
// Execution priority, lower value for higher priority
|
||||
int priority_ = 0;
|
||||
|
||||
// set from op_->SinceVersion() or via deserialization when OpSchema is not available
|
||||
int since_version_ = -1;
|
||||
|
||||
|
|
@ -850,6 +861,13 @@ class Graph {
|
|||
const std::function<bool(const Node*, const Node*)>& comp,
|
||||
const std::function<bool(const Node*, const Node*)>& stop) const;
|
||||
|
||||
/** Performs topological sort with Kahn's algorithm on the graph/s.
|
||||
@param enter Visit function that will be invoked on a node when it is visited.
|
||||
@param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
|
||||
*/
|
||||
void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
|
||||
const std::function<bool(const Node*, const Node*)>& comp) const;
|
||||
|
||||
/** Gets the map of operator domains to their opset versions. */
|
||||
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
|
||||
return domain_to_version_;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/framework/session_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class Function;
|
||||
|
|
@ -87,7 +88,7 @@ class GraphViewer {
|
|||
int MaxNodeIndex() const noexcept;
|
||||
|
||||
/** Gets the NodeIndex values for the Graph nodes, sorted into topological order. */
|
||||
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const;
|
||||
const std::vector<NodeIndex>& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const;
|
||||
|
||||
/**
|
||||
Gets the NodeIndex values for the root nodes in the Graph.
|
||||
|
|
@ -144,6 +145,10 @@ class GraphViewer {
|
|||
|
||||
// The NodeIndex values of the graph nodes sorted in topological order.
|
||||
std::vector<NodeIndex> nodes_in_topological_order_;
|
||||
|
||||
// The NodeIndex values of the graph nodes sorted in topological order with priority.
|
||||
std::vector<NodeIndex> nodes_in_topological_order_with_priority_;
|
||||
|
||||
// Graph root nodes.
|
||||
std::vector<NodeIndex> root_nodes_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -768,7 +768,7 @@ class PlannerImpl {
|
|||
}; // namespace onnxruntime
|
||||
|
||||
Status PlannerImpl::CreatePlan() {
|
||||
auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder();
|
||||
auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(context_.GetExecutionOrder());
|
||||
|
||||
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
|
||||
|
||||
|
|
|
|||
|
|
@ -28,22 +28,28 @@ class ISequentialPlannerContext {
|
|||
// If it returns true, planner won't reuse output tensors
|
||||
// see PlannerImpl::ComputeReusePlan
|
||||
virtual bool IsParallelExecutionEnabled() const { return false; }
|
||||
|
||||
virtual ExecutionOrder GetExecutionOrder() const { return ExecutionOrder::DEFAULT; }
|
||||
};
|
||||
|
||||
class SequentialPlannerContext : public ISequentialPlannerContext {
|
||||
public:
|
||||
SequentialPlannerContext(ExecutionMode execution_mode)
|
||||
: m_execution_mode(execution_mode) {
|
||||
SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order)
|
||||
: execution_mode_(execution_mode),
|
||||
exection_order_(execution_order) {
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* GetShape(const onnxruntime::NodeArg& arg) const override {
|
||||
return arg.Shape();
|
||||
}
|
||||
|
||||
bool IsParallelExecutionEnabled() const override { return m_execution_mode == ExecutionMode::ORT_PARALLEL; }
|
||||
bool IsParallelExecutionEnabled() const override { return execution_mode_ == ExecutionMode::ORT_PARALLEL; }
|
||||
|
||||
ExecutionOrder GetExecutionOrder() const override { return exection_order_; }
|
||||
|
||||
private:
|
||||
ExecutionMode m_execution_mode = ExecutionMode::ORT_SEQUENTIAL;
|
||||
ExecutionMode execution_mode_ = ExecutionMode::ORT_SEQUENTIAL;
|
||||
ExecutionOrder exection_order_ = ExecutionOrder::DEFAULT;
|
||||
};
|
||||
|
||||
class SequentialPlanner {
|
||||
|
|
|
|||
|
|
@ -11,12 +11,25 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Selects which topological ordering of the graph nodes is used.
enum class ExecutionOrder {
  // Default topological sort.
  DEFAULT = 0,
  // Priority-based topological sort.
  PRIORITY_BASED = 1
};
|
||||
|
||||
enum class FreeDimensionOverrideType {
|
||||
Invalid = 0,
|
||||
Denotation = 1,
|
||||
Name = 2
|
||||
};
|
||||
|
||||
// Named execution priority levels for a node. A lower value means a higher priority.
enum class ExecutionPriority : int {
  GLOBAL_HIGH = -100,
  // Misspelled original name; kept as an alias so existing callers keep compiling.
  GLOBAL_HIGHT = GLOBAL_HIGH,
  LOCAL_HIGH = -10,
  DEFAULT = 0,
  LOCAL_LOW = 10,
  GLOBAL_LOW = 100
};
|
||||
|
||||
struct FreeDimensionOverride {
|
||||
std::string dim_identifier;
|
||||
FreeDimensionOverrideType dim_identifer_type;
|
||||
|
|
@ -29,6 +42,9 @@ struct FreeDimensionOverride {
|
|||
struct SessionOptions {
|
||||
ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL;
|
||||
|
||||
// set the execution order of the graph
|
||||
ExecutionOrder execution_order = ExecutionOrder::DEFAULT;
|
||||
|
||||
// enable profiling for this session.
|
||||
bool enable_profiling = false;
|
||||
|
||||
|
|
|
|||
|
|
@ -832,7 +832,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
});
|
||||
}
|
||||
|
||||
SequentialPlannerContext context(session_options.execution_mode);
|
||||
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order);
|
||||
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
|
||||
execution_providers_, kernel_create_info_map_,
|
||||
ort_value_name_idx_map_, context, p_seq_exec_plan_));
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <stack>
|
||||
#include <queue>
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/common/logging/logging.h"
|
||||
|
|
@ -420,6 +421,10 @@ const Node* Node::NodeConstIterator::operator->() const {
|
|||
return &(operator*());
|
||||
}
|
||||
|
||||
// Records the given execution priority on this node (lower value means higher priority).
void Node::SetPriority(int priority) noexcept {
  priority_ = priority;
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
void Node::SetNodeType(Node::Type node_type) noexcept {
|
||||
|
|
@ -677,6 +682,7 @@ void Node::Init(const std::string& name,
|
|||
definitions_.input_defs = input_args;
|
||||
definitions_.output_defs = output_args;
|
||||
domain_ = domain;
|
||||
priority_ = 0;
|
||||
if (kOnnxDomainAlias == domain_) {
|
||||
domain_ = kOnnxDomain;
|
||||
}
|
||||
|
|
@ -1560,6 +1566,44 @@ void Graph::ReverseDFSFrom(const std::vector<const Node*>& from,
|
|||
}
|
||||
}
|
||||
|
||||
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
|
||||
const std::function<bool(const Node*, const Node*)>& comp) const {
|
||||
std::unordered_map<NodeIndex, size_t> in_degree;
|
||||
std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
|
||||
std::vector<NodeIndex> topo_order;
|
||||
|
||||
for (auto& node : Nodes()) {
|
||||
size_t input_edge_count = node.GetInputEdgesCount();
|
||||
in_degree.insert({node.Index(), input_edge_count});
|
||||
if (input_edge_count == 0) {
|
||||
to_visit.push(&node);
|
||||
}
|
||||
}
|
||||
|
||||
while (!to_visit.empty()) {
|
||||
const Node* current = to_visit.top();
|
||||
to_visit.pop();
|
||||
|
||||
if (!current) continue;
|
||||
|
||||
if (enter) {
|
||||
enter(current);
|
||||
}
|
||||
|
||||
for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) {
|
||||
in_degree[node_it->Index()]--;
|
||||
|
||||
if (in_degree[node_it->Index()] == 0) {
|
||||
to_visit.push(&*node_it);
|
||||
}
|
||||
}
|
||||
topo_order.push_back(current->Index());
|
||||
}
|
||||
|
||||
if (NumberOfNodes() != static_cast<int>(topo_order.size())) {
|
||||
ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
|
||||
}
|
||||
}
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
|
||||
|
|
|
|||
|
|
@ -14,15 +14,45 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
|
|||
return n1->Index() < n2->Index();
|
||||
}
|
||||
|
||||
struct PriorityNodeCompare {
|
||||
inline bool IsHighPri(const Node* n) const {
|
||||
static const std::unordered_set<std::string> high_pri_ops = {"Shape", "Size"};
|
||||
return high_pri_ops.find(n->OpType()) != high_pri_ops.end();
|
||||
}
|
||||
|
||||
// Used for std::priority_queue
|
||||
// If return false, n1 will be output first
|
||||
// If return true, n2 will be output first
|
||||
bool operator()(const Node* n1, const Node* n2) const {
|
||||
// nodes in global high priorty list will be output first
|
||||
if (IsHighPri(n1) != IsHighPri(n2)) {
|
||||
return IsHighPri(n2);
|
||||
}
|
||||
|
||||
// nodes with lower priority value will be output first
|
||||
if (n1->Priority() != n2->Priority()) {
|
||||
return n1->Priority() > n2->Priority();
|
||||
}
|
||||
|
||||
// otherwise, nodes with lower index will be output first
|
||||
return n1->Index() > n2->Index();
|
||||
}
|
||||
};
|
||||
|
||||
GraphViewer::GraphViewer(const Graph& graph) {
|
||||
graph_ = &graph;
|
||||
std::vector<const Node*> leaf_nodes;
|
||||
for (auto& node : graph_->Nodes()) {
|
||||
// This is a leaf node (without any output node)
|
||||
if (node.OutputNodesBegin() == node.OutputNodesEnd()) {
|
||||
// This is a leaf node (without any output node).
|
||||
leaf_nodes.push_back(&node);
|
||||
}
|
||||
// This is a root node (without any input node)
|
||||
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
|
||||
root_nodes_.push_back(node.Index());
|
||||
}
|
||||
}
|
||||
|
||||
graph.ReverseDFSFrom(
|
||||
leaf_nodes,
|
||||
nullptr,
|
||||
|
|
@ -31,11 +61,11 @@ GraphViewer::GraphViewer(const Graph& graph) {
|
|||
},
|
||||
NodeCompare());
|
||||
|
||||
for (auto& node : graph_->Nodes()) {
|
||||
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
|
||||
root_nodes_.push_back(node.Index());
|
||||
}
|
||||
}
|
||||
graph.KahnsTopologicalSort(
|
||||
[this](const Node* n) {
|
||||
nodes_in_topological_order_with_priority_.push_back(n->Index());
|
||||
},
|
||||
PriorityNodeCompare());
|
||||
}
|
||||
|
||||
// Graph name.
|
||||
|
|
@ -92,8 +122,15 @@ int GraphViewer::MaxNodeIndex() const noexcept {
|
|||
return graph_->MaxNodeIndex();
|
||||
}
|
||||
|
||||
const std::vector<NodeIndex>& GraphViewer::GetNodesInTopologicalOrder() const {
|
||||
return nodes_in_topological_order_;
|
||||
const std::vector<NodeIndex>& GraphViewer::GetNodesInTopologicalOrder(ExecutionOrder order) const {
|
||||
switch (order) {
|
||||
case ExecutionOrder::DEFAULT:
|
||||
return nodes_in_topological_order_;
|
||||
case ExecutionOrder::PRIORITY_BASED:
|
||||
return nodes_in_topological_order_with_priority_;
|
||||
default:
|
||||
ORT_THROW("Invalide ExecutionOrder");
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<NodeIndex>& GraphViewer::GetRootNodes() const {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1089,6 +1089,8 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
|
|||
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
|
||||
.def_readwrite("execution_mode", &PySessionOptions::execution_mode,
|
||||
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
|
||||
.def_readwrite("execution_order", &SessionOptions::execution_order,
|
||||
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
|
||||
.def_property(
|
||||
"graph_optimization_level",
|
||||
[](const PySessionOptions* options) -> GraphOptimizationLevel {
|
||||
|
|
|
|||
|
|
@ -231,12 +231,12 @@ TEST_F(GraphTest, SimpleUnique) {
|
|||
std::shared_ptr<Model> model;
|
||||
ASSERT_STATUS_OK(Model::Load(std::move(m), model, nullptr, *logger_));
|
||||
}
|
||||
|
||||
|
||||
TEST_F(GraphTest, UnusedValueInfoSerializes) {
|
||||
ModelProto m;
|
||||
m.set_ir_version(4);
|
||||
ImportOpset(m, "", 11);
|
||||
GraphProto& g = *m.mutable_graph();
|
||||
GraphProto& g = *m.mutable_graph();
|
||||
NodeProto* node = g.add_node();
|
||||
*node->add_input() = "x";
|
||||
*node->add_output() = "sum";
|
||||
|
|
@ -633,9 +633,6 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
|
|||
// node_5 (Merge)
|
||||
// |
|
||||
|
||||
std::unordered_map<std::string, std::pair<std::vector<NodeArg*>, std::vector<NodeArg*>>>
|
||||
expected_node_name_to_input_output_args;
|
||||
|
||||
TypeProto tensor_int32;
|
||||
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
|
||||
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
|
@ -655,29 +652,24 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
|
|||
|
||||
inputs.push_back(&input_arg1);
|
||||
outputs.push_back(&output_arg1);
|
||||
expected_node_name_to_input_output_args["node_1"] = {inputs, outputs};
|
||||
graph.AddNode("node_1", "Identity_Fake", "node 1", inputs, outputs);
|
||||
|
||||
inputs[0] = &input_arg2;
|
||||
outputs[0] = &output_arg2;
|
||||
expected_node_name_to_input_output_args["node_2"] = {inputs, outputs};
|
||||
graph.AddNode("node_2", "Identity_Fake", "node 2", inputs, outputs);
|
||||
|
||||
inputs[0] = &output_arg2;
|
||||
outputs[0] = &output_arg3;
|
||||
expected_node_name_to_input_output_args["node_3"] = {inputs, outputs};
|
||||
graph.AddNode("node_3", "Identity_Fake", "node 3", inputs, outputs);
|
||||
|
||||
inputs[0] = &output_arg1;
|
||||
outputs[0] = &output_arg4;
|
||||
expected_node_name_to_input_output_args["node_4"] = {inputs, outputs};
|
||||
graph.AddNode("node_4", "Identity_Fake", "node 4", inputs, outputs);
|
||||
|
||||
inputs.resize(2);
|
||||
inputs[0] = &output_arg4;
|
||||
inputs[1] = &output_arg3;
|
||||
outputs[0] = &output_arg5;
|
||||
expected_node_name_to_input_output_args["node_5"] = {inputs, outputs};
|
||||
graph.AddNode("node_5", "Merge_Fake", "node 3", inputs, outputs);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
|
|
@ -700,6 +692,223 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) {
|
|||
}
|
||||
}
|
||||
|
||||
// Checks both execution orders on a diamond-shaped graph whose second branch carries a
// LOCAL_HIGH "compress" node followed by a LOCAL_LOW "decompress" node.
TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompress) {
  Model model("graph_1", false, *logger_);
  auto& graph = model.MainGraph();

  /*
                      |
               node_0 (Identity)
                 /          \
     node_1 (Identity)    compress (pri = LOCAL_HIGH)
             |                |
     node_4 (Identity)    decompress (pri = LOCAL_LOW)
                 \          /
               node_5 (Merge)
                      |
  */

  TypeProto tensor_int32;
  tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
  tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32);
  auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32);
  auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
  auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32);
  auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out_1", &tensor_int32);
  auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32);
  auto& output_arg5 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32);

  graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
  graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});

  auto& compress_node = graph.AddNode("compress", "Identity_Fake", "compress node", {&output_arg0}, {&output_arg2});
  compress_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_HIGH));

  auto& decompress_node = graph.AddNode("decompress", "Identity_Fake", "decompress node", {&output_arg2}, {&output_arg3});
  decompress_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));

  graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4});
  graph.AddNode("node_5", "Merge_Fake", "node 5", {&output_arg4, &output_arg3}, {&output_arg5});

  // Nothing below is meaningful if the graph failed to resolve, so stop here on failure.
  auto status = graph.Resolve();
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
  GraphViewer graph_viewer(graph);

  // PRIORITY_BASED order
  {
    auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
    const std::vector<std::string> expected_priority_based_order =
        {"node_0", "compress", "node_1", "node_4", "decompress", "node_5"};
    // Guards the indexing below and catches missing or extra nodes in the returned order.
    ASSERT_EQ(order.size(), expected_priority_based_order.size());
    for (size_t i = 0; i < order.size(); ++i) {
      auto node = graph.GetNode(order[i]);
      ASSERT_TRUE(node != nullptr);
      EXPECT_EQ(node->Name(), expected_priority_based_order[i]) << "Priority based execution order is wrong.";
    }
  }

  // TOPOLOGICAL order
  {
    auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT);
    const std::vector<std::string> expected_topological_order = {
        "node_0", "node_1", "node_4", "compress", "decompress", "node_5"};
    ASSERT_EQ(order.size(), expected_topological_order.size());
    for (size_t i = 0; i < order.size(); ++i) {
      auto node = graph.GetNode(order[i]);
      ASSERT_TRUE(node != nullptr);
      EXPECT_EQ(node->Name(), expected_topological_order[i]) << "Default topological execution order is wrong.";
    }
  }
}
|
||||
|
||||
// Checks the priority-based order on a diamond-shaped graph whose second branch is a single
// LOCAL_LOW recompute node.
TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_Recompute) {
  Model model("graph_1", false, *logger_);
  auto& graph = model.MainGraph();

  /*
                      |
               node_0 (Identity)
                 /          \
     node_1 (Identity)    recompute_node_1 (pri = LOCAL_LOW)
             |                |
     node_4 (Identity)        |
                 \          /
             node_1_grad (Merge)
                      |
  */

  TypeProto tensor_int32;
  tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
  tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32);
  auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32);
  auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
  auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32);
  auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32);
  auto& output_arg5 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32);

  graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
  graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});

  auto& recompute_node = graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2});
  recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));

  graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4});
  graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg4, &output_arg2}, {&output_arg5});

  // Nothing below is meaningful if the graph failed to resolve, so stop here on failure.
  auto status = graph.Resolve();
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
  GraphViewer graph_viewer(graph);

  // PRIORITY_BASED order
  {
    auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
    const std::vector<std::string> expected_priority_based_order =
        {"node_0", "node_1", "node_4", "recompute_node_1", "node_1_grad"};
    // Guards the indexing below and catches missing or extra nodes in the returned order.
    ASSERT_EQ(order.size(), expected_priority_based_order.size());
    for (size_t i = 0; i < order.size(); ++i) {
      auto node = graph.GetNode(order[i]);
      ASSERT_TRUE(node != nullptr);
      EXPECT_EQ(node->Name(), expected_priority_based_order[i]) << "Priority based execution order is wrong.";
    }
  }
}
|
||||
|
||||
// Checks the priority-based order on a three-layer forward chain with one LOCAL_LOW recompute
// node per layer and a backward chain that consumes the recomputed outputs.
TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_MultiLayerRecompute) {
  Model model("graph_1", false, *logger_);
  auto& graph = model.MainGraph();

  /*
  Edges built below (producer(s) -> consumer):

  Forward chain (Identity):
      node_0 -> node_1 -> node_2 -> node_3 -> loss

  Recompute nodes (Identity, pri = LOCAL_LOW):
      node_2 -> node_3_recompute
      node_1 -> node_2_recompute
      node_0 -> node_1_recompute

  Backward chain (Merge):
      gradient_start (graph input), node_3 -> loss_grad
      loss_grad,   node_3_recompute        -> node_3_grad
      node_3_grad, node_2_recompute        -> node_2_grad
      node_2_grad, node_1_recompute        -> node_1_grad
  */

  TypeProto tensor_int32;
  tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
  tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  // FW graph
  auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in", &tensor_int32);
  auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out", &tensor_int32);
  auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out", &tensor_int32);
  auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out", &tensor_int32);
  auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out", &tensor_int32);
  auto& output_loss = graph.GetOrCreateNodeArg("loss_out", &tensor_int32);

  graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
  graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});
  graph.AddNode("node_2", "Identity_Fake", "node 2", {&output_arg1}, {&output_arg2});
  graph.AddNode("node_3", "Identity_Fake", "node 3", {&output_arg2}, {&output_arg3});
  graph.AddNode("loss", "Identity_Fake", "loss node", {&output_arg3}, {&output_loss});

  // Recompute graph
  auto& recomputed_arg3 = graph.GetOrCreateNodeArg("node_3_out_recomputed", &tensor_int32);
  auto& recomputed_arg2 = graph.GetOrCreateNodeArg("node_2_out_recomputed", &tensor_int32);
  auto& recomputed_arg1 = graph.GetOrCreateNodeArg("node_1_out_recomputed", &tensor_int32);

  auto& recompute_node3 = graph.AddNode("node_3_recompute", "Identity_Fake", "node 3 recompute", {&output_arg2}, {&recomputed_arg3});
  auto& recompute_node2 = graph.AddNode("node_2_recompute", "Identity_Fake", "node 2 recompute", {&output_arg1}, {&recomputed_arg2});
  auto& recompute_node1 = graph.AddNode("node_1_recompute", "Identity_Fake", "node 1 recompute", {&output_arg0}, {&recomputed_arg1});
  recompute_node3.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
  recompute_node2.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
  recompute_node1.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));

  // BW Graph
  auto& gradient_start = graph.GetOrCreateNodeArg("gradient_start", &tensor_int32);
  auto& loss_grad_output = graph.GetOrCreateNodeArg("loss_grad_output", &tensor_int32);
  auto& node_3_grad_output = graph.GetOrCreateNodeArg("node_3_grad_output", &tensor_int32);
  auto& node_2_grad_output = graph.GetOrCreateNodeArg("node_2_grad_output", &tensor_int32);
  auto& node_1_grad_output = graph.GetOrCreateNodeArg("node_1_grad_output", &tensor_int32);

  graph.AddNode("loss_grad", "Merge_Fake", "loss gradient", {&gradient_start, &output_arg3}, {&loss_grad_output});
  graph.AddNode("node_3_grad", "Merge_Fake", "node 3 gradient", {&loss_grad_output, &recomputed_arg3}, {&node_3_grad_output});
  graph.AddNode("node_2_grad", "Merge_Fake", "node 2 gradient", {&node_3_grad_output, &recomputed_arg2}, {&node_2_grad_output});
  graph.AddNode("node_1_grad", "Merge_Fake", "node 1 gradient", {&node_2_grad_output, &recomputed_arg1}, {&node_1_grad_output});

  // Nothing below is meaningful if the graph failed to resolve, so stop here on failure.
  auto status = graph.Resolve();
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
  GraphViewer graph_viewer(graph);

  // PRIORITY_BASED order
  {
    auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
    const std::vector<std::string> expected_priority_based_order = {
        "node_0",
        "node_1",
        "node_2",
        "node_3",
        "loss",
        "loss_grad",
        "node_3_recompute",
        "node_3_grad",
        "node_2_recompute",
        "node_2_grad",
        "node_1_recompute",
        "node_1_grad",
    };
    // Guards the indexing below and catches missing or extra nodes in the returned order.
    ASSERT_EQ(order.size(), expected_priority_based_order.size());
    for (size_t i = 0; i < order.size(); ++i) {
      auto node = graph.GetNode(order[i]);
      ASSERT_TRUE(node != nullptr);
      EXPECT_EQ(node->Name(), expected_priority_based_order[i]) << "Priority based execution order is wrong.";
    }
  }
}
|
||||
|
||||
TEST_F(GraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) {
|
||||
Model model("graph_1", false, *logger_);
|
||||
auto& graph = model.MainGraph();
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace std;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
@ -20,8 +19,8 @@ namespace training {
|
|||
using namespace common;
|
||||
|
||||
GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
|
||||
const unordered_set<string>& y_node_arg_names,
|
||||
const unordered_set<string>& x_node_arg_names,
|
||||
const std::unordered_set<std::string>& y_node_arg_names,
|
||||
const std::unordered_set<std::string>& x_node_arg_names,
|
||||
const std::string& loss_node_arg_name,
|
||||
const GradientGraphConfiguration& gradient_graph_config,
|
||||
const logging::Logger& logger)
|
||||
|
|
@ -61,6 +60,11 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
|
|||
y_nodes_.insert(node);
|
||||
}
|
||||
|
||||
reachable_nodes_ = ReverseBFS(y_nodes_);
|
||||
|
||||
std::string unreachable_nodes;
|
||||
|
||||
// building x_nodes_
|
||||
for (const auto& name : x_node_arg_names) {
|
||||
const NodeArg* node_arg = graph->GetNodeArg(name);
|
||||
if (!node_arg) {
|
||||
|
|
@ -68,21 +72,29 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
|
|||
}
|
||||
x_node_args_.insert(node_arg);
|
||||
|
||||
vector<const Node*> nodes = graph_->GetConsumerNodes(name);
|
||||
std::vector<const Node*> nodes = graph_->GetConsumerNodes(name);
|
||||
if (nodes.empty()) {
|
||||
ORT_THROW(name, " couldn't find the consumer node.");
|
||||
}
|
||||
|
||||
string grad_arg_name = GradientBuilderBase::GradientName(name);
|
||||
pending_[grad_arg_name] = static_cast<int>(nodes.size());
|
||||
std::string grad_arg_name = GradientBuilderBase::GradientName(name);
|
||||
pending_[grad_arg_name] = 0;
|
||||
|
||||
x_nodes_.insert(nodes.begin(), nodes.end());
|
||||
for (const Node* node : nodes) {
|
||||
if (IsReachable(node)) {
|
||||
pending_[grad_arg_name] += 1;
|
||||
x_nodes_.insert(node);
|
||||
} else {
|
||||
unreachable_nodes.append(node->Name() + ", ");
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGS(logger_, WARNING) << "Following nodes are unreachable for gradient back propagation: " << unreachable_nodes;
|
||||
}
|
||||
|
||||
NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
|
||||
NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) const {
|
||||
NodeSet visited(nodes);
|
||||
deque<const Node*> queue(nodes.begin(), nodes.end());
|
||||
std::deque<const Node*> queue(nodes.begin(), nodes.end());
|
||||
|
||||
while (!queue.empty()) {
|
||||
const Node* n = queue.front();
|
||||
|
|
@ -106,13 +118,13 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) {
|
|||
return visited;
|
||||
}
|
||||
|
||||
Status GradientGraphBuilder::CheckNodeArgsReachable(const NodeSet& reachable_nodes) {
|
||||
Status GradientGraphBuilder::CheckNodeArgsReachable() const {
|
||||
for (const NodeArg* node_arg : x_node_args_) {
|
||||
auto nodes = graph_->GetConsumerNodes(node_arg->Name());
|
||||
|
||||
bool reachable = false;
|
||||
for (const Node* node : nodes) {
|
||||
if (reachable_nodes.find(node) != reachable_nodes.end()) {
|
||||
if (IsReachable(node)) {
|
||||
reachable = true;
|
||||
break;
|
||||
}
|
||||
|
|
@ -141,14 +153,13 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
|
|||
gradient_graph_defs.AddInitializers({tensor_proto});
|
||||
}
|
||||
|
||||
NodeSet reachable_nodes = ReverseBFS(y_nodes_);
|
||||
|
||||
ORT_RETURN_IF_ERROR(CheckNodeArgsReachable(reachable_nodes));
|
||||
ORT_RETURN_IF_ERROR(CheckNodeArgsReachable());
|
||||
|
||||
// Going forward to figure out which node_args need backprop-ed.
|
||||
deque<const Node*> queue(x_nodes_.begin(), x_nodes_.end());
|
||||
std::deque<const Node*> queue(x_nodes_.begin(), x_nodes_.end());
|
||||
NodeSet visited(x_nodes_);
|
||||
unordered_set<const NodeArg*> visited_node_args = x_node_args_;
|
||||
|
||||
std::unordered_set<const NodeArg*> visited_node_args = x_node_args_;
|
||||
visited_node_args.insert(y_node_args_.begin(), y_node_args_.end());
|
||||
|
||||
while (!queue.empty()) {
|
||||
|
|
@ -158,7 +169,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
|
|||
for (auto edge_it = node->OutputEdgesBegin(); edge_it != node->OutputEdgesEnd(); ++edge_it) {
|
||||
const Node& next_node = edge_it->GetNode();
|
||||
|
||||
if (reachable_nodes.find(&next_node) == reachable_nodes.end()) continue;
|
||||
if (!IsReachable(&next_node)) continue;
|
||||
|
||||
auto it = STOP_GRADIENT_EDGES.find(next_node.OpType());
|
||||
if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) {
|
||||
|
|
@ -168,7 +179,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
|
|||
}
|
||||
|
||||
const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()];
|
||||
string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name());
|
||||
std::string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name());
|
||||
|
||||
pending_[grad_node_arg_name] += 1;
|
||||
|
||||
|
|
@ -185,7 +196,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
|
|||
// visited_node_args are the node_args involved
|
||||
for (auto node : visited) {
|
||||
//TODO: might not need two sets, the union of them might be enough
|
||||
unordered_set<string> input_args_need_grad, output_args_need_grad;
|
||||
std::unordered_set<std::string> input_args_need_grad, output_args_need_grad;
|
||||
for (auto arg : node->InputDefs()) {
|
||||
if (visited_node_args.find(arg) != visited_node_args.end()) {
|
||||
input_args_need_grad.insert(arg->Name());
|
||||
|
|
@ -205,7 +216,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set<std::string>* p_init
|
|||
auto found = pending_.find(arg.name);
|
||||
if (found != pending_.end() && found->second > 1) {
|
||||
auto idx = gradients_to_accumulate_[arg].size();
|
||||
string indexed_arg_name = arg.name + "_" + to_string(idx);
|
||||
std::string indexed_arg_name = arg.name + "_" + to_string(idx);
|
||||
gradients_to_accumulate_[arg].push_back(ArgDef(indexed_arg_name, arg.type_proto));
|
||||
|
||||
arg.name = indexed_arg_name;
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ class GradientGraphBuilder {
|
|||
|
||||
NodeSet y_nodes_;
|
||||
NodeSet x_nodes_;
|
||||
NodeSet reachable_nodes_;
|
||||
|
||||
Graph* graph_;
|
||||
|
||||
|
|
@ -117,14 +118,22 @@ class GradientGraphBuilder {
|
|||
@param nodes Starting nodes for ReverseBFS
|
||||
@returns All the nodes visited during ReverseBFS
|
||||
*/
|
||||
NodeSet ReverseBFS(const NodeSet& nodes);
|
||||
NodeSet ReverseBFS(const NodeSet& nodes) const;
|
||||
|
||||
/**
|
||||
Check if 'x_node_args_' are reachable from 'y_node_args_' for computing the partial derivative
|
||||
@param reachable_nodes All the nodes reachable from the 'y_node_args_'
|
||||
@returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status
|
||||
*/
|
||||
Status CheckNodeArgsReachable(const NodeSet& reachable_nodes);
|
||||
|
||||
Status CheckNodeArgsReachable() const;
|
||||
|
||||
/**
|
||||
Check if node is reachable from the 'y_node_args_'
|
||||
**/
|
||||
bool IsReachable(const Node* node) const {
|
||||
return reachable_nodes_.find(node) != reachable_nodes_.end();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -663,9 +663,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) {
|
|||
}
|
||||
}
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef("ReshapeGrad",
|
||||
{I(0), GO(0)},
|
||||
{GI(0)})};
|
||||
NodeDef("Shape", {I(0)}, {IA("x_shape")}),
|
||||
NodeDef("Reshape", {GO(0), IA("x_shape")}, {GI(0)})};
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) {
|
||||
|
|
|
|||
|
|
@ -95,6 +95,15 @@ class GradientBuilderBase {
|
|||
// i-th output of forward op
|
||||
ArgDef O(const size_t i) const {
|
||||
ORT_ENFORCE(i < node_->OutputDefs().size());
|
||||
|
||||
const std::string& name = node_->OutputDefs()[i]->Name();
|
||||
const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name));
|
||||
if (recomputed_nodearg) {
|
||||
const Node* producer_node = graph_->GetProducerNode(name);
|
||||
LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name();
|
||||
return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto());
|
||||
}
|
||||
|
||||
return ArgDef(node_->OutputDefs()[i]->Name(), node_->OutputDefs()[i]->TypeAsProto());
|
||||
}
|
||||
|
||||
|
|
|
|||
38
orttraining/orttraining/core/optimizer/dropout_recompute.cc
Normal file
38
orttraining/orttraining/core/optimizer/dropout_recompute.cc
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/optimizer/dropout_recompute.h"
|
||||
#include "orttraining/core/graph/recompute_graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Adds a "DropoutGrad" node (kMSDomain) named "<node name>_recompute" to 'graph' and returns it.
// The new node consumes the mask produced by 'node' (output 1) together with the ratio and
// training_mode inputs of 'node' (inputs 1 and 2), and writes to the NodeArg named
// graph_utils::RecomputeName(<output 0 of node>).
// Its data input is input 0 of 'node' when 'use_original_input' is true, otherwise the
// NodeArg named graph_utils::RecomputeName(<input 0 of node>).
Node& InsertDropoutRecompute(Graph& graph, Node& node, bool use_original_input) {
  // X
  NodeArg* data_input = node.MutableInputDefs()[0];
  if (!use_original_input) {
    data_input = &graph.GetOrCreateNodeArg(graph_utils::RecomputeName(data_input->Name()),
                                           data_input->TypeAsProto());
  }

  const NodeArg* dropout_output = node.OutputDefs()[0];
  NodeArg& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(dropout_output->Name()),
                                                        dropout_output->TypeAsProto());

  NodeArg* mask = node.MutableOutputDefs()[1];
  NodeArg* ratio = node.MutableInputDefs()[1];
  NodeArg* training_mode = node.MutableInputDefs()[2];

  return graph.AddNode(node.Name() + "_recompute",
                       "DropoutGrad",
                       "Recompute of " + node.Name(),
                       {data_input, mask, ratio, training_mode},
                       {&recomputed_output},
                       {},
                       kMSDomain);
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
12
orttraining/orttraining/core/optimizer/dropout_recompute.h
Normal file
12
orttraining/orttraining/core/optimizer/dropout_recompute.h
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Adds a recompute node for the Dropout 'node' to 'graph' and returns the newly added node.
// 'use_original_input' selects whether the recompute node reads the original data input of
// 'node' (true) or the recomputed counterpart of that input (false).
Node& InsertDropoutRecompute(Graph& graph, Node& node, bool use_original_input);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -43,6 +43,7 @@
|
|||
#include "orttraining/core/optimizer/localized_recompute.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
#include "orttraining/core/optimizer/transformer_layer_recompute.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
@ -73,10 +74,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
if (config.gelu_checkpoint) {
|
||||
if (config.gelu_recompute) {
|
||||
rule_transformer->Register(make_unique<GeluRecompute>());
|
||||
}
|
||||
if (config.attn_dropout_checkpoint) {
|
||||
if (config.attn_dropout_recompute) {
|
||||
rule_transformer->Register(make_unique<AttentionDropoutRecompute>());
|
||||
}
|
||||
|
||||
|
|
@ -104,6 +105,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
horizontal_parallel_size, compatible_eps));
|
||||
}
|
||||
transformers.emplace_back(onnxruntime::make_unique<ComputationReductionTransformer>(compatible_eps));
|
||||
|
||||
if (config.transformer_layer_recompute) {
|
||||
transformers.emplace_back(onnxruntime::make_unique<TransformerLayerRecompute>(compatible_eps));
|
||||
}
|
||||
} break;
|
||||
|
||||
case TransformerLevel::Level2: {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/graph/graph_utils.h"
|
||||
#include "orttraining/core/graph/recompute_graph_utils.h"
|
||||
#include "orttraining/core/optimizer/localized_recompute.h"
|
||||
#include "orttraining/core/optimizer/dropout_recompute.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
|
|
@ -23,13 +24,15 @@ Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
|
|||
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
|
||||
output->TypeAsProto());
|
||||
|
||||
graph.AddNode(node.Name() + "_recompute",
|
||||
node.OpType(),
|
||||
"Recompute of " + node.Name(),
|
||||
{node.MutableInputDefs()[0]},
|
||||
{&recomputed_output},
|
||||
&node.GetAttributes(),
|
||||
node.Domain());
|
||||
Node& recompute_node = graph.AddNode(node.Name() + "_recompute",
|
||||
node.OpType(),
|
||||
"Recompute of " + node.Name(),
|
||||
{node.MutableInputDefs()[0]},
|
||||
{&recomputed_output},
|
||||
&node.GetAttributes(),
|
||||
node.Domain());
|
||||
|
||||
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
|
||||
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
return Status::OK();
|
||||
|
|
@ -46,23 +49,8 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const N
|
|||
}
|
||||
|
||||
Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const {
|
||||
const auto& output = node.OutputDefs()[0];
|
||||
|
||||
auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
|
||||
output->TypeAsProto());
|
||||
|
||||
graph.AddNode(node.Name() + "_recompute",
|
||||
"DropoutGrad", // Reusing DropoutGrad as the recompute op
|
||||
"Recompute of " + node.Name(),
|
||||
{
|
||||
node.MutableInputDefs()[0], // X
|
||||
node.MutableOutputDefs()[1], // mask
|
||||
node.MutableInputDefs()[1], // ratio
|
||||
node.MutableInputDefs()[2] // training_mode
|
||||
},
|
||||
{&recomputed_output},
|
||||
{},
|
||||
kMSDomain);
|
||||
Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true);
|
||||
recompute_node.SetPriority(static_cast<int>(ExecutionPriority::LOCAL_LOW));
|
||||
|
||||
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -0,0 +1,189 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/optimizer/transformer_layer_recompute.h"
|
||||
#include "orttraining/core/optimizer/dropout_recompute.h"
|
||||
#include "orttraining/core/graph/recompute_graph_utils.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
#include <deque>
|
||||
|
||||
namespace onnxruntime {
|
||||
// Pattern-matches the boundaries of transformer layers while walking the graph in topological order.
//
// A layer start is output 0 of a LayerNorm-type or Dropout-type node that has exactly 4 output edges.
// A layer end is found from each Gelu-type node by following the first consumer of every node
// until a Dropout-type node is met, then continuing until a LayerNorm-type node is met;
// output 0 of that LayerNorm-type node is the end edge.
//
// @param graph Graph to inspect.
// @param start_end_edges Receives one (start, end) pair per layer, the i-th start paired with the i-th end.
//                        Only cleared and refilled when the match succeeds.
// @param logger Logger used to report the detected layers.
// @returns OK on success, or an error status when the number of starts and ends differ.
Status TransformerLayerRecompute::IdentifyTransformerLayerEdges(
    const Graph& graph,
    std::vector<std::pair<const NodeArg*, const NodeArg*>>& start_end_edges,
    const logging::Logger& logger) const {
  const std::unordered_set<std::string> gelu_ops{"Gelu", "BiasGelu", "FastGelu"};
  const std::unordered_set<std::string> dropout_ops{"Dropout", "BiasDropout"};
  const std::unordered_set<std::string> layernorm_ops{"LayerNormalization", "SkipLayerNormalization"};

  std::vector<const NodeArg*> layer_start_edges, layer_end_edges;
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  for (auto node_index : node_topology_list) {
    auto& node = *graph.GetNode(node_index);

    // Look for start of a transformer layer
    if ((layernorm_ops.find(node.OpType()) != layernorm_ops.end() ||
         dropout_ops.find(node.OpType()) != dropout_ops.end()) &&
        node.GetOutputEdgesCount() == 4) {
      layer_start_edges.push_back(node.OutputDefs()[0]);
    }

    // Look for end of a transformer layer
    if (gelu_ops.find(node.OpType()) != gelu_ops.end()) {
      auto next_node = node.OutputNodesBegin();

      // A Gelu without any consumer cannot lead to a layer end, and its begin iterator
      // equals the end iterator, which must not be dereferenced.
      if (next_node == node.OutputNodesEnd()) {
        continue;
      }

      // Follow the first consumer of each node until a Dropout (or a node without consumers)
      while (next_node->OutputNodesBegin() != next_node->OutputNodesEnd() &&
             dropout_ops.find(next_node->OpType()) == dropout_ops.end()) {
        next_node = next_node->OutputNodesBegin();
      }

      // ... then until a LayerNorm (or a node without consumers)
      while (next_node->OutputNodesBegin() != next_node->OutputNodesEnd() &&
             layernorm_ops.find(next_node->OpType()) == layernorm_ops.end()) {
        next_node = next_node->OutputNodesBegin();
      }

      if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
        layer_end_edges.push_back(next_node->OutputDefs()[0]);
      }
    }
  }

  ORT_RETURN_IF_NOT(layer_start_edges.size() == layer_end_edges.size(),
                    "Number of start and end edges doesn't match!, #start=", layer_start_edges.size(),
                    ", #end=", layer_end_edges.size());

  start_end_edges.clear();

  LOGS(logger, INFO) << "Found " << layer_start_edges.size() << " transformer layers.";
  for (size_t i = 0; i < layer_start_edges.size(); ++i) {
    start_end_edges.push_back({layer_start_edges[i], layer_end_edges[i]});
    LOGS(logger, INFO) << "Start edge: " << layer_start_edges[i]->Name() << " End edge: " << layer_end_edges[i]->Name();
  }

  return Status::OK();
}
|
||||
|
||||
namespace {
|
||||
|
||||
typedef std::set<const Node*, NodeCompare> NodeSet;
|
||||
|
||||
NodeSet BFSFrom(const std::vector<const Node*>& start_nodes, bool reverse) {
|
||||
NodeSet visited(start_nodes.begin(), start_nodes.end());
|
||||
std::deque<const Node*> queue(start_nodes.begin(), start_nodes.end());
|
||||
while (!queue.empty()) {
|
||||
const Node* n = queue.front();
|
||||
queue.pop_front();
|
||||
|
||||
auto begin = reverse ? n->InputNodesBegin() : n->OutputNodesBegin();
|
||||
auto end = reverse ? n->InputNodesEnd() : n->OutputNodesEnd();
|
||||
|
||||
for (auto node_it = begin; node_it != end; ++node_it) {
|
||||
const Node& node = *node_it;
|
||||
if (visited.find(&node) == visited.end()) {
|
||||
queue.push_back(&node);
|
||||
visited.insert(&node);
|
||||
}
|
||||
}
|
||||
}
|
||||
return visited;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<const Node*> TransformerLayerRecompute::NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const {
|
||||
// Forward BFS from the start node
|
||||
std::vector<const Node*> start_nodes = graph.GetConsumerNodes(start->Name());
|
||||
NodeSet fw_visited = BFSFrom(start_nodes, /*reverse*/ false);
|
||||
|
||||
// Reverse BFS from the end node
|
||||
const Node* end_node = graph.GetProducerNode(end->Name());
|
||||
NodeSet bw_visited = BFSFrom({end_node}, /*reverse*/ true);
|
||||
|
||||
// Join fw_visited and bw_visited
|
||||
std::vector<const Node*> intersect_nodes;
|
||||
std::set_intersection(fw_visited.begin(), fw_visited.end(),
|
||||
bw_visited.begin(), bw_visited.end(),
|
||||
std::back_inserter(intersect_nodes), NodeCompare());
|
||||
|
||||
return intersect_nodes;
|
||||
}
|
||||
|
||||
// Adds a "<name>_recompute" twin for every node in 'nodes' and assigns 'priority' to each twin.
//
// Outputs of a twin are the NodeArgs named graph_utils::RecomputeName(<original output name>).
// An input of a twin is the original NodeArg when it is an initializer or when its producer is
// not part of 'nodes'; otherwise it is the recomputed counterpart of that input.
// "Dropout" nodes are handled by InsertDropoutRecompute so the original mask is reused.
//
// @param graph Graph receiving the recompute nodes.
// @param nodes Nodes to duplicate.
// @param priority Execution priority given to every inserted node.
void TransformerLayerRecompute::InsertRecomputeNodes(Graph& graph, const std::vector<const Node*>& nodes, int priority) const {
  // Bind by reference so the initializer map is not copied; nothing below adds or removes initializers.
  const auto& initializers = graph.GetAllInitializedTensors();

  // Constant-time membership test instead of a linear scan of 'nodes' for every input
  const std::unordered_set<const Node*> nodes_to_recompute(nodes.begin(), nodes.end());

  // do not duplicate initializers in recompute subgraph,
  // and keep inputs whose producer is outside of the recomputed nodes
  const auto use_original = [&](const NodeArg* input) {
    const Node* p_node = graph.GetProducerNode(input->Name());
    return initializers.find(input->Name()) != initializers.end() ||
           nodes_to_recompute.count(p_node) == 0;
  };

  for (const Node* n : nodes) {
    Node* node = graph.GetNode(n->Index());

    // recomputed Dropout need to produce the same output as original dropout
    // currently reusing original dropout's mask to achieve this
    if (node->OpType() == "Dropout") {
      const bool use_original_input = use_original(node->InputDefs()[0]);

      Node& recompute_node = InsertDropoutRecompute(graph, *node, use_original_input);
      recompute_node.SetPriority(priority);
      continue;
    }

    // prepare inputs for recompute node
    std::vector<NodeArg*> recomputed_inputs;
    for (NodeArg* input : node->MutableInputDefs()) {
      if (use_original(input)) {
        recomputed_inputs.push_back(input);
      } else {
        auto& recomputed_input = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(input->Name()),
                                                          input->TypeAsProto());
        recomputed_inputs.push_back(&recomputed_input);
      }
    }

    // prepare outputs for recompute node
    std::vector<NodeArg*> recomputed_outputs;
    for (NodeArg* output : node->MutableOutputDefs()) {
      auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()),
                                                         output->TypeAsProto());
      recomputed_outputs.push_back(&recomputed_output);
    }

    Node& recompute_node = graph.AddNode(node->Name() + "_recompute",
                                         node->OpType(),
                                         "Recompute of " + node->Name(),
                                         recomputed_inputs,
                                         recomputed_outputs,
                                         &node->GetAttributes(),
                                         node->Domain());
    recompute_node.SetPriority(priority);
  }
}
|
||||
|
||||
// Inserts recompute nodes for every detected transformer layer except the last one.
//
// @param graph Graph to transform.
// @param modified Set to true when recompute nodes were inserted, false when the graph was left untouched.
// @param logger Logger used for diagnostics.
// @returns Always OK: a graph whose layers cannot be identified is left as-is rather than failing the transformation.
Status TransformerLayerRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const {
  std::vector<std::pair<const NodeArg*, const NodeArg*>> start_end_edges;

  const Status s = IdentifyTransformerLayerEdges(graph, start_end_edges, logger);
  if (!s.IsOK()) {
    // Not fatal, but report why the transformation is skipped instead of dropping the error silently
    LOGS(logger, WARNING) << "TransformerLayerRecompute is skipped: " << s.ErrorMessage();
    modified = false;
    return Status::OK();
  }

  // The last layer is never recomputed, so fewer than two layers means there is nothing to insert.
  // This also keeps the unsigned 'size() - 1' bound below from wrapping around on an empty result.
  const size_t num_layers = start_end_edges.size();
  if (num_layers < 2) {
    modified = false;
    return Status::OK();
  }

  // insert recompute nodes except for the last transformer layer
  // latter recompute layers have higher execution priority (i.e. a lower priority value)
  for (size_t i = 0; i < num_layers - 1; ++i) {
    std::vector<const Node*> nodes = NodesBetweenEdges(graph, start_end_edges[i].first, start_end_edges[i].second);
    InsertRecomputeNodes(graph, nodes, static_cast<int>(num_layers - i));
  }

  modified = true;
  return Status::OK();
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Graph transformer that inserts recompute nodes for transformer layers.
class TransformerLayerRecompute : public GraphTransformer {
|
||||
public:
|
||||
// Registers the transformer under the name "TransformerLayerRecompute" and forwards
// 'compatible_execution_providers' to the GraphTransformer base.
TransformerLayerRecompute(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("TransformerLayerRecompute", compatible_execution_providers) {}
|
||||
|
||||
// Entry point of the transformation; reports through 'modified' whether the graph was changed.
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
private:
|
||||
// Fills 'start_end_edges' with one (start edge, end edge) pair per detected transformer layer.
Status IdentifyTransformerLayerEdges(const Graph& graph,
|
||||
std::vector<std::pair<const NodeArg*, const NodeArg*>>& start_end_edges,
|
||||
const logging::Logger& logger) const;
|
||||
|
||||
// Returns the nodes located between the 'start' edge and the 'end' edge.
std::vector<const Node*> NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const;
|
||||
|
||||
// Adds a recompute node for each node in 'nodes' and gives it the execution priority 'priority'.
void InsertRecomputeNodes(Graph& graph, const std::vector<const Node*>& nodes, int priority) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -217,11 +217,6 @@ Status TrainingSession::ConfigureForTraining(
|
|||
config_result.mixed_precision_config_result = mp_result;
|
||||
}
|
||||
|
||||
if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) {
|
||||
ORT_IGNORE_RETURN_VALUE(Save(
|
||||
config.model_with_loss_function_path.value(), SaveOption::NO_RELOAD));
|
||||
}
|
||||
|
||||
// We need to get trainable weights to prevent constant folding from them. This works well if trainable weights are passed from config.
|
||||
// For case we use GetTrainableModelInitializers to get trainable weights such as C++ frontend, it may get more initializers
|
||||
// than trainable weights here as it's before transformers. So the constant folding may miss some nodes we actually can fold.
|
||||
|
|
@ -239,6 +234,11 @@ Status TrainingSession::ConfigureForTraining(
|
|||
|
||||
ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config));
|
||||
|
||||
if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) {
|
||||
ORT_IGNORE_RETURN_VALUE(Save(
|
||||
config.model_with_loss_function_path.value(), SaveOption::NO_RELOAD));
|
||||
}
|
||||
|
||||
// derive actual set of weights to train
|
||||
std::unordered_set<std::string> weight_names_to_train =
|
||||
!filtered_config_weight_names_to_train.empty()
|
||||
|
|
|
|||
|
|
@ -194,10 +194,12 @@ class TrainingSession : public InferenceSession {
|
|||
struct GraphTransformerConfiguration {
|
||||
// Whether to enable GELU approximation which is faster but produces different results.
|
||||
bool enable_gelu_approximation{false};
|
||||
// Enable checkpointing of attention dropout to save memory
|
||||
bool attn_dropout_checkpoint{false};
|
||||
// Enable checkpointing of Gelu activation output to save memory
|
||||
bool gelu_checkpoint{false};
|
||||
// Enable recompute of attention dropout to save memory
|
||||
bool attn_dropout_recompute{false};
|
||||
// Enable recompute of Gelu activation output to save memory
|
||||
bool gelu_recompute{false};
|
||||
// Enable recompute of transformer layer output to save memory
|
||||
bool transformer_layer_recompute{false};
|
||||
};
|
||||
|
||||
GraphTransformerConfiguration graph_transformer_config{};
|
||||
|
|
|
|||
|
|
@ -167,9 +167,11 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
cxxopts::value<bool>()->default_value("true"))
|
||||
("enable_gelu_approximation", "Specify whether to enable GELU approximation.",
|
||||
cxxopts::value<bool>()->default_value("true"))
|
||||
("attn_dropout_checkpoint", "Enable checkpointing of attention dropout to save memory.",
|
||||
("attn_dropout_recompute", "Enable checkpointing of attention dropout to save memory.",
|
||||
cxxopts::value<bool>()->default_value("false"))
|
||||
("gelu_checkpoint", "Enable checkpointing of Gelu activation output to save memory.",
|
||||
("gelu_recompute", "Enable checkpointing of Gelu activation output to save memory.",
|
||||
cxxopts::value<bool>()->default_value("false"))
|
||||
("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.",
|
||||
cxxopts::value<bool>()->default_value("false"))
|
||||
("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
|
|
@ -458,8 +460,9 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet
|
|||
}
|
||||
|
||||
params.enable_gelu_approximation = flags["enable_gelu_approximation"].as<bool>();
|
||||
params.attn_dropout_checkpoint = flags["attn_dropout_checkpoint"].as<bool>();
|
||||
params.gelu_checkpoint = flags["gelu_checkpoint"].as<bool>();
|
||||
params.attn_dropout_recompute = flags["attn_dropout_recompute"].as<bool>();
|
||||
params.gelu_recompute = flags["gelu_recompute"].as<bool>();
|
||||
params.transformer_layer_recompute = flags["transformer_layer_recompute"].as<bool>();
|
||||
|
||||
ort_params.log_severity = static_cast<logging::Severity>(flags["ort_log_severity"].as<int>());
|
||||
ORT_RETURN_IF_NOT(
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ namespace training {
|
|||
static std::vector<FreeDimensionOverride> overrides = {};
|
||||
static SessionOptions SESSION_OPTION = {
|
||||
ExecutionMode::ORT_SEQUENTIAL, //execution_mode
|
||||
ExecutionOrder::PRIORITY_BASED, //execution_order
|
||||
false, //enable_profiling
|
||||
ORT_TSTR(""), //optimized_model_filepath
|
||||
true, //enable_mem_pattern
|
||||
|
|
@ -183,8 +184,9 @@ Status TrainingRunner::Initialize() {
|
|||
{
|
||||
TrainingSession::TrainingConfiguration::GraphTransformerConfiguration gt_config{};
|
||||
gt_config.enable_gelu_approximation = params_.enable_gelu_approximation;
|
||||
gt_config.attn_dropout_checkpoint = params_.attn_dropout_checkpoint;
|
||||
gt_config.gelu_checkpoint = params_.gelu_checkpoint;
|
||||
gt_config.attn_dropout_recompute = params_.attn_dropout_recompute;
|
||||
gt_config.gelu_recompute = params_.gelu_recompute;
|
||||
gt_config.transformer_layer_recompute = params_.transformer_layer_recompute;
|
||||
|
||||
config.graph_transformer_config = gt_config;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -175,10 +175,11 @@ class TrainingRunner {
|
|||
// Enable GELU approximation
|
||||
bool enable_gelu_approximation = false;
|
||||
// Enable checkpointing of attention dropout to save memory
|
||||
bool attn_dropout_checkpoint = false;
|
||||
bool attn_dropout_recompute = false;
|
||||
// Enable checkpointing of Gelu activation output to save memory
|
||||
bool gelu_checkpoint = false;
|
||||
|
||||
bool gelu_recompute = false;
|
||||
// Enable checkpointing of transformer layer output to save memory
|
||||
bool transformer_layer_recompute = false;
|
||||
// Use invertible layernorm grad
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue