mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
support pipeline partition with shared initializer (#4321)
* support bert partition with shared initializer * address feedback * address feedback * address feedback * add more test * remove bert-tiny model * address feedback * address function comment * move CreateNodeArg to graph_utils * rename function name * rename function name * fix windows build * fix windows type conversion warning * add function comment
This commit is contained in:
parent
1ebe598286
commit
7d96960ec8
9 changed files with 739 additions and 306 deletions
|
|
@ -140,6 +140,25 @@ class Node {
|
|||
return common::Status::OK();
|
||||
}
|
||||
|
||||
/**
|
||||
Helper to iterate through the container returned by #MutableInputDefs() or #MutableOutputDefs() and call the provided function.
|
||||
@param node_args Collection of NodeArgs returned by #MutableInputDefs() or #MutableOutputDefs()
|
||||
@param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
|
||||
and the index number in the container.
|
||||
@returns common::Status with success or error information.
|
||||
@remarks Returns immediately on error.
|
||||
*/
|
||||
static common::Status ForEachMutableWithIndex(std::vector<NodeArg*>& node_args,
|
||||
std::function<common::Status(NodeArg& arg, size_t index)> func) {
|
||||
for (size_t index = 0; index < node_args.size(); ++index) {
|
||||
auto arg = node_args[index];
|
||||
if (!arg->Exists())
|
||||
continue;
|
||||
ORT_RETURN_IF_ERROR(func(*arg, index));
|
||||
}
|
||||
return common::Status::OK();
|
||||
}
|
||||
|
||||
/** Gets the count of arguments for each of the Node's explicit inputs. */
|
||||
const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
|
||||
|
||||
|
|
@ -507,6 +526,9 @@ class Graph {
|
|||
/** Remove the initializer tensor with the provided name from the Graph. */
|
||||
void RemoveInitializedTensor(const std::string& tensor_name);
|
||||
|
||||
/** Check if a given name is an initializer tensor's name in this graph. */
|
||||
bool IsInitializedTensor(const std::string& name) const;
|
||||
|
||||
/** Replaces the initializer tensor with the same name as the given initializer tensor.
|
||||
The replacement initializer tensor must have the same type and shape as the existing initializer tensor.
|
||||
|
||||
|
|
@ -635,15 +657,14 @@ class Graph {
|
|||
if (iter != node_args_.end()) {
|
||||
return *(iter->second);
|
||||
}
|
||||
|
||||
auto result = node_args_.insert(std::make_pair(name, onnxruntime::make_unique<NodeArg>(name, p_arg_type)));
|
||||
return *(result.first->second);
|
||||
}
|
||||
|
||||
/** Generate a unique name.in this Graph for a NodeArg */
|
||||
/** Generate a unique name in this Graph for a NodeArg */
|
||||
std::string GenerateNodeArgName(const std::string& base_name);
|
||||
|
||||
/** Generate a unique name.in this Graph for a Node */
|
||||
/** Generate a unique name in this Graph for a Node */
|
||||
std::string GenerateNodeName(const std::string& base_name);
|
||||
|
||||
/** Add a Node to this Graph.
|
||||
|
|
|
|||
|
|
@ -253,6 +253,21 @@ Path& Path::Append(const Path& other) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/**
 * Appends `string` verbatim to the last component of this path; no directory separator is introduced.
 * If the path has no components yet, `string` becomes its first component.
 * @param string Text to append.
 * @returns *this, to allow call chaining.
 */
Path& Path::Concat(const PathString& string) {
  // Calling back() on an empty container is undefined behavior, so an empty path is handled explicitly.
  // NOTE(review): assumes components_ is a sequence container of PathString offering empty()/push_back()
  // (its declaration is not visible from here) - confirm.
  if (components_.empty()) {
    components_.push_back(string);
  } else {
    components_.back() += string;
  }
  return *this;
}
|
||||
|
||||
/**
 * Appends the decimal text of `index` to the last component of this path; no directory separator is
 * introduced. If the path has no components yet, the index text becomes its first component.
 * @param index Value to append (negative values keep their '-' sign, as produced by std::to_string).
 * @returns *this, to allow call chaining.
 */
Path& Path::ConcatIndex(const int index) {
  // The index text is built as a wide string on Windows and a narrow string elsewhere.
#ifdef _WIN32
  auto index_str = std::to_wstring(index);
#else
  auto index_str = std::to_string(index);
#endif
  // Calling back() on an empty container is undefined behavior, so an empty path is handled explicitly.
  // NOTE(review): assumes components_ is a sequence container offering empty()/push_back()
  // (its declaration is not visible from here) - confirm.
  if (components_.empty()) {
    components_.push_back(index_str);
  } else {
    components_.back() += index_str;
  }
  return *this;
}
|
||||
|
||||
Status RelativePath(const Path& src, const Path& dst, Path& rel) {
|
||||
ORT_RETURN_IF_NOT(
|
||||
src.GetRootPathString() == dst.GetRootPathString(),
|
||||
|
|
|
|||
|
|
@ -61,6 +61,19 @@ class Path {
|
|||
* The algorithm should model that of std::filesystem::path::append().
|
||||
*/
|
||||
Path& Append(const Path& other);
|
||||
|
||||
/**
|
||||
* Concatenates the current path and the argument string.
|
||||
* Unlike with Append() or operator/=, additional directory separators are never introduced.
|
||||
*/
|
||||
Path& Concat(const PathString& string);
|
||||
|
||||
/**
|
||||
* Concatenates an index to the end of the current path.
|
||||
* Similar to Concat() except the argument is an index.
|
||||
*/
|
||||
Path& ConcatIndex(const int index);
|
||||
|
||||
/** Equivalent to this->Append(other). */
|
||||
Path& operator/=(const Path& other) {
|
||||
return Append(other);
|
||||
|
|
|
|||
|
|
@ -2331,6 +2331,10 @@ static void RemoveRepeatedFieldEntry(T& repeated_field, const TIter& entry_to_re
|
|||
}
|
||||
}
|
||||
|
||||
bool Graph::IsInitializedTensor(const std::string& name) const {
|
||||
return name_to_initial_tensor_.count(name) > 0;
|
||||
}
|
||||
|
||||
void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
|
||||
bool found = false;
|
||||
auto iter = name_to_initial_tensor_.find(tensor_name);
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ static void RemoveGraphEdges(Graph& graph, const std::vector<GraphEdge>& edges)
|
|||
}
|
||||
|
||||
/** Given a graph, a list of edges, and a NodeArg name, checks if each of the edges provides an implicit input
|
||||
to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs.
|
||||
to a subgraph. If so, it checks if there is no clash of the given NodeArg name in each of the subgraphs.
|
||||
This is important when removing a node with this NodeArg as input. */
|
||||
static bool CanUpdateImplicitInputNameInSubgraphs(const Graph& graph,
|
||||
const std::vector<GraphEdge>& output_edges,
|
||||
|
|
@ -321,7 +321,7 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
|
|||
return iter == attrs.end() ? nullptr : &iter->second;
|
||||
}
|
||||
|
||||
/** Checks for nodes with >= 1 outputs, if only one of the outputs is input to downstream Operators.
|
||||
/** Checks for nodes with >= 1 outputs, if only one of the outputs is input to downstream Operators.
|
||||
Returns the name of the single used output in output_name. */
|
||||
static bool IsOnlyOneOutputUsed(const Graph& graph, const Node& node, const std::string*& output_name) {
|
||||
const int unassigned = -1;
|
||||
|
|
@ -807,5 +807,10 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Gets or creates a NodeArg in `graph` whose name is generated from base_arg's name and whose
    type information is taken from base_arg.TypeAsProto().
    @returns Reference to the NodeArg returned by Graph::GetOrCreateNodeArg for the generated name. */
NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg) {
  const auto mirrored_name = graph.GenerateNodeArgName(base_arg.Name());
  const auto* mirrored_type = base_arg.TypeAsProto();
  return graph.GetOrCreateNodeArg(mirrored_name, mirrored_type);
}
|
||||
|
||||
} // namespace graph_utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -34,27 +34,27 @@ bool IsOutputUsed(const Node& node, int index);
|
|||
/** Returns true if the graph has the given input.*/
|
||||
bool IsGraphInput(const Graph& graph, const NodeArg* input);
|
||||
|
||||
/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true.
|
||||
/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true.
|
||||
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
|
||||
*/
|
||||
bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope);
|
||||
|
||||
/** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
|
||||
/** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
|
||||
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
|
||||
*/
|
||||
bool IsConstantInitializer(const Graph& graph, const std::string& name, bool check_outer_scope = true);
|
||||
|
||||
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
|
||||
/** returns the initializer's TensorProto if 'name' is an initializer, is constant and
|
||||
cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
|
||||
@param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
|
||||
*/
|
||||
const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const Graph& graph, const std::string& name,
|
||||
bool check_outer_scope = true);
|
||||
|
||||
/** Add a new initializer to 'graph'.
|
||||
/** Add a new initializer to 'graph'.
|
||||
Checks that new_initializer does not already exist in 'graph' before adding it.
|
||||
@returns The NodeArg for the new initializer.
|
||||
@remarks No matching graph input is created, so the initializer will be constant.
|
||||
@returns The NodeArg for the new initializer.
|
||||
@remarks No matching graph input is created, so the initializer will be constant.
|
||||
*/
|
||||
NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer);
|
||||
|
||||
|
|
@ -127,7 +127,7 @@ bool RemoveNode(Graph& graph, Node& node);
|
|||
|
||||
/** Tests if we can remove a node and replace its output with an initializer.
|
||||
Conditions:
|
||||
- Only one of the node's outputs is used by downstream operators or as a graph output
|
||||
- Only one of the node's outputs is used by downstream operators or as a graph output
|
||||
- multiple edges for the single used output are allowed
|
||||
- If the node produces a graph output the initializer_name must be the same as the node's output name
|
||||
- otherwise the required graph output will not be produced
|
||||
|
|
@ -140,20 +140,20 @@ bool CanReplaceNodeWithInitializer(const Graph& graph, const Node& node, const s
|
|||
See CanReplaceNodeWithInitializer for the conditions that must be satisfied in order to remove the node.*/
|
||||
bool ReplaceNodeWithInitializer(Graph& graph, Node& node, NodeArg& replacement);
|
||||
|
||||
/** Removes all output edges from the given Node of the Graph.
|
||||
/** Removes all output edges from the given Node of the Graph.
|
||||
This should probably be elevated to the Graph API eventually. */
|
||||
size_t RemoveNodeOutputEdges(Graph& graph, Node& node);
|
||||
|
||||
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
|
||||
/** Replaces the input to nodes that are downstream from 'node', which was being provided by an output of 'node',
|
||||
with an output from a different node. Moves the output edges from 'node' for 'output_idx' to the replacement node.
|
||||
@param replacement The node providing the replacement output.
|
||||
@param replacement_output_idx The index of the output from 'replacement' to use.
|
||||
@param replacement_output_idx The index of the output from 'replacement' to use.
|
||||
|
||||
e.g. Node A produces outputs A1 and A2.
|
||||
Node B consumes A2 (edge between A and B for A2) and produces B1.
|
||||
e.g. Node A produces outputs A1 and A2.
|
||||
Node B consumes A2 (edge between A and B for A2) and produces B1.
|
||||
Node C consumes B1 (edge between B and C for B1).
|
||||
|
||||
If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1)
|
||||
|
||||
If Node B was determined to not be needed, you would call ReplaceDownstreamNodeInput(graph, B, 0, A, 1)
|
||||
to replace B1 (output index 0 for node B) with A2 (output index 1 for node A) as input to the downstream node C.
|
||||
The edge that existed between B and C for B1 will be removed, and replaced with an edge between A and C for A2.
|
||||
*/
|
||||
|
|
@ -161,29 +161,29 @@ void ReplaceDownstreamNodeInput(Graph& graph, Node& node, int output_idx, Node&
|
|||
|
||||
/** Replace the input to a node with a NodeArg.
|
||||
@remarks The replacement only updates the node's input definition and does not create any edges,
|
||||
as typically this function is used to replace an input with an initializer or graph input
|
||||
as typically this function is used to replace an input with an initializer or graph input
|
||||
(there is no edge between an initializer or graph input and a Node).
|
||||
*/
|
||||
void ReplaceNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
|
||||
|
||||
/** Add an input to a node with a NodeArg for an initializer or graph input.
|
||||
@remarks target_input_idx must be the next input slot.
|
||||
e.g. if a Node has 2 inputs, AddNodeInput can only add input 3 and not 4.
|
||||
There is no edge between an initializer or graph input and a Node, so the replacement only updates the
|
||||
@remarks target_input_idx must be the next input slot.
|
||||
e.g. if a Node has 2 inputs, AddNodeInput can only add input 3 and not 4.
|
||||
There is no edge between an initializer or graph input and a Node, so the replacement only updates the
|
||||
node's input definition and does not create any new edges.
|
||||
*/
|
||||
void AddNodeInput(Node& target, int target_input_idx, NodeArg& new_input);
|
||||
|
||||
/** Finalize the fusion of second_node into first_node.
|
||||
/** Finalize the fusion of second_node into first_node.
|
||||
The output definitions and edges from the second_node are moved to first_node. second_node is deleted.
|
||||
e.g. Conv + Add fusion fuses the 'Add' into the Conv.
|
||||
*/
|
||||
void FinalizeNodeFusion(Graph& graph, Node& first_node, Node& second_node);
|
||||
|
||||
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
|
||||
/** Finalize the fusion of two or more nodes which are being replaced with a single node.
|
||||
The first and last entries in 'nodes' are assumed to be the first and last nodes in a chain of nodes being fused.
|
||||
|
||||
Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
|
||||
Conceptually multiple nodes are being combined into one, and post-fusion will produce output/s with the same names
|
||||
as the last node in 'nodes', and be connected to the same downstream nodes.
|
||||
|
||||
The input edges to the first node in 'nodes' will be moved to replacement_node. No other input edges are moved.
|
||||
|
|
@ -229,7 +229,7 @@ struct EdgeEndToMatch {
|
|||
@param edges_to_match has information of a sequence of adjacent edges in the path to be matched one by one.
|
||||
@param result stores edges that are found.
|
||||
@returns false when one edge has multiple candidates, or not all edges are found.
|
||||
@remarks matching an EdgeEndToMatch might get multiple candidates in output edges.
|
||||
@remarks matching an EdgeEndToMatch might get multiple candidates in output edges.
|
||||
When such case is encountered, this function will return false. This is by design to reduce complexity.
|
||||
Here is an example graph:
|
||||
Add
|
||||
|
|
@ -237,7 +237,7 @@ struct EdgeEndToMatch {
|
|||
Mul Mul
|
||||
\ /
|
||||
Sub
|
||||
For example, you want to match path from top to bottom: Add-->Mul-->Sub.
|
||||
For example, you want to match path from top to bottom: Add-->Mul-->Sub.
|
||||
When matching the first edge Add-->Mul, the algorithm found two matches.
|
||||
Then it returns false, and output a warning log entry.
|
||||
|
||||
|
|
@ -252,9 +252,15 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec
|
|||
|
||||
/**
|
||||
* Remove nodes with only one output edge using bottom-up bfs traversal.
|
||||
* @param node: The node to start with.
|
||||
* @param node: The node to start with.
|
||||
* @returns true if there is one or more node(s) removed by this function. Otherwise return false.
|
||||
*/
|
||||
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& node);
|
||||
|
||||
/** Creates a mutable NodeArg owned by the graph that mirrors base_arg's TypeProto under a newly generated name.
|
||||
* @param base_arg The NodeArg the newly created NodeArg is mirrored based off.
|
||||
* @returns Reference to the new NodeArg, which carries the same TypeProto info as base_arg but a different, generated name.
|
||||
*/
|
||||
NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg);
|
||||
} // namespace graph_utils
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -222,5 +222,34 @@ TEST(PathTest, RelativePathFailure) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// Verifies that Path::Concat appends text to the last path component without adding a separator.
TEST(PathTest, Concat) {
  // Parses `base`, concatenates `suffix` onto it and compares the result with the parsed `expected` path.
  const auto verify = [](const std::string& base, const std::string& suffix, const std::string& expected) {
    Path actual_path{};
    ASSERT_STATUS_OK(Path::Parse(ToPathString(base), actual_path));

    Path expected_path{};
    ASSERT_STATUS_OK(Path::Parse(ToPathString(expected), expected_path));

    const auto suffix_str = ToPathString(suffix);
    EXPECT_EQ(actual_path.Concat(suffix_str).ToPathString(), expected_path.ToPathString());
  };

  verify("/a/b", "c", "/a/bc");
  verify("a/b", "cd", "a/bcd");
}
|
||||
|
||||
// Verifies that Path::ConcatIndex appends the decimal text of an index to the last path component.
TEST(PathTest, ConcatIndex) {
  // Parses `base`, appends `index` to it and compares the result with the parsed `expected` path.
  const auto verify = [](const std::string& base, const int index, const std::string& expected) {
    Path actual_path{};
    ASSERT_STATUS_OK(Path::Parse(ToPathString(base), actual_path));

    Path expected_path{};
    ASSERT_STATUS_OK(Path::Parse(ToPathString(expected), expected_path));

    EXPECT_EQ(actual_path.ConcatIndex(index).ToPathString(), expected_path.ToPathString());
  };

  verify("/a/b", 0, "/a/b0");
  verify("a/b", 123, "a/b123");
  // Negative indices keep their sign in the resulting component.
  verify("a/b", -1, "a/b-1");
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@
|
|||
#include "orttraining/core/graph/pipeline_transformer.h"
|
||||
#include <queue>
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
|
||||
using namespace onnxruntime::common;
|
||||
using namespace onnxruntime::graph_utils;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
|
@ -78,17 +81,15 @@ std::vector<NodeArg*> FindBackwardLeafNodes(Graph& graph) {
|
|||
// The newly created boolean scalar may be appended to signal_args. If
|
||||
// signal_args is empty, the source of signal_args[i] would be tensor_args[i].
|
||||
void ConvertTensorToBoolSignal(
|
||||
Graph& graph,
|
||||
const std::vector<NodeArg*>& tensor_args,
|
||||
std::vector<NodeArg*>& signal_args) {
|
||||
|
||||
for (auto tensor_arg: tensor_args) {
|
||||
Graph& graph,
|
||||
const std::vector<NodeArg*>& tensor_args,
|
||||
std::vector<NodeArg*>& signal_args) {
|
||||
for (auto tensor_arg : tensor_args) {
|
||||
// Declare the scalar signal this "tensor_arg" will be converted into.
|
||||
auto signal_arg = &CreateTypedNodeArg(
|
||||
graph,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
"signal_" + tensor_arg->Name()
|
||||
);
|
||||
graph,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
"signal_" + tensor_arg->Name());
|
||||
|
||||
// Add the new scalar to user-specified vector.
|
||||
signal_args.push_back(signal_arg);
|
||||
|
|
@ -98,37 +99,28 @@ void ConvertTensorToBoolSignal(
|
|||
std::vector<NodeArg*> input_args{tensor_arg};
|
||||
std::vector<NodeArg*> output_args{signal_arg};
|
||||
graph.AddNode(
|
||||
name,
|
||||
"Group",
|
||||
"",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
name,
|
||||
"Group",
|
||||
"",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
}
|
||||
}
|
||||
|
||||
// Return mirror variables for node_arg with a different name.
|
||||
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
|
||||
const auto& new_name = graph.GenerateNodeArgName(base_arg->Name());
|
||||
ONNX_NAMESPACE::TypeProto type_proto(*(base_arg->TypeAsProto()));
|
||||
if (graph.GetNodeArg(new_name) != nullptr) {
|
||||
ORT_THROW("Node with name ", new_name, " already exists.");
|
||||
}
|
||||
return graph.GetOrCreateNodeArg(new_name, &type_proto);
|
||||
}
|
||||
|
||||
// Return mirror variables for node_args.
|
||||
// The i-th output element mirrors node_args[i] but with a different name.
|
||||
std::vector<NodeArg*> CreateMirrorNodeArgs(
|
||||
Graph& graph,
|
||||
const std::vector<NodeArg*>& node_args) {
|
||||
Graph& graph,
|
||||
const std::vector<NodeArg*>& node_args) {
|
||||
// Declare output.
|
||||
std::vector<NodeArg*> new_node_args;
|
||||
|
||||
for (auto& node_arg: node_args) {
|
||||
for (auto& node_arg : node_args) {
|
||||
// new_node_arg is a mirror variable of node_arg. They have the same type.
|
||||
auto new_node_arg = &CreateNodeArg(graph, node_arg);
|
||||
assert(node_arg);
|
||||
auto new_node_arg = &CreateNodeArg(graph, *node_arg);
|
||||
new_node_args.push_back(new_node_arg);
|
||||
}
|
||||
|
||||
|
|
@ -138,33 +130,33 @@ std::vector<NodeArg*> CreateMirrorNodeArgs(
|
|||
// Create a node with input schema [event, input1, input2, ..., inputN] and
|
||||
// output schema [input1, input2, ..., inputN]
|
||||
Node& CreateBottleneckNode(Graph& graph,
|
||||
const std::string& op_type,
|
||||
const std::string& op_name,
|
||||
const std::string& description,
|
||||
NodeArg* event,
|
||||
std::vector<NodeArg*> input_node_args,
|
||||
std::vector<NodeArg*> output_node_args) {
|
||||
const std::string& op_type,
|
||||
const std::string& op_name,
|
||||
const std::string& description,
|
||||
NodeArg* event,
|
||||
std::vector<NodeArg*> input_node_args,
|
||||
std::vector<NodeArg*> output_node_args) {
|
||||
const auto name = graph.GenerateNodeName(op_name);
|
||||
if (event) {
|
||||
input_node_args.insert(input_node_args.begin(), event);
|
||||
}
|
||||
|
||||
return graph.AddNode(
|
||||
name,
|
||||
op_type,
|
||||
description,
|
||||
input_node_args,
|
||||
output_node_args,
|
||||
nullptr /* assume all bottleneck node have no attributes */,
|
||||
kMSDomain);
|
||||
name,
|
||||
op_type,
|
||||
description,
|
||||
input_node_args,
|
||||
output_node_args,
|
||||
nullptr /* assume all bottleneck node have no attributes */,
|
||||
kMSDomain);
|
||||
}
|
||||
|
||||
Node* AddBackwardRecord(Graph& graph,
|
||||
Node* backward_send,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::vector<std::string>& new_output_names,
|
||||
std::string &event_id_tensor_name,
|
||||
std::string &output_tensor_name) {
|
||||
std::string& event_id_tensor_name,
|
||||
std::string& output_tensor_name) {
|
||||
std::vector<NodeArg*> input_args;
|
||||
AddNewNodeArg(graph, "backward_recorded_event_id", ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
input_args, new_input_names);
|
||||
|
|
@ -194,13 +186,16 @@ Node* AddBackwardRecord(Graph& graph,
|
|||
// Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
|
||||
// the added Record backward op will have its first passthrough input as output.
|
||||
ORT_ENFORCE(input_args.size() >= 2, "RecordEvent backward op at least have two inputs.");
|
||||
auto& new_output = CreateNodeArg(graph, input_args[1]); // the first input is signal, not passing through
|
||||
|
||||
// RecordEvent doesn't have optional input, so it cannot be nullptr.
|
||||
assert(input_args[1]);
|
||||
auto& new_output = CreateNodeArg(graph, *input_args[1]); // the first input is signal, not passing through
|
||||
output_args.push_back(&new_output);
|
||||
new_output_names.push_back(new_output.Name());
|
||||
|
||||
Node* record_node = &CreateBottleneckNode(
|
||||
graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
|
||||
input_args, output_args);
|
||||
graph, "RecordEvent", "backward_record", "Backward pass", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
// First input argument is the recorded event ID tensor.
|
||||
event_id_tensor_name = input_args.front()->Name();
|
||||
|
|
@ -212,21 +207,19 @@ Node* AddBackwardRecord(Graph& graph,
|
|||
}
|
||||
|
||||
Node* AddForwardWait(Graph& graph,
|
||||
Node* /* forward_recv */,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::string& forward_waited_event_name,
|
||||
std::string& output_tensor_name) {
|
||||
Node* /* forward_recv */,
|
||||
std::vector<std::string>& new_input_names,
|
||||
std::string& forward_waited_event_name,
|
||||
std::string& output_tensor_name) {
|
||||
// Append old_input to input_args and return its pass-through value. Note that
|
||||
// input_args and output_args are Wait's inputs and outputs, respectively.
|
||||
auto update_wait_input_output = [&](NodeArg* old_input,
|
||||
std::vector<NodeArg*>& input_args,
|
||||
std::vector<NodeArg*>& output_args) -> NodeArg& {
|
||||
assert(old_input);
|
||||
input_args.push_back(old_input);
|
||||
|
||||
const auto& new_name = graph.GenerateNodeArgName(old_input->Name());
|
||||
ONNX_NAMESPACE::TypeProto type_proto(*(old_input->TypeAsProto()));
|
||||
|
||||
auto& wait_output = graph.GetOrCreateNodeArg(new_name, &type_proto);
|
||||
auto& wait_output = CreateNodeArg(graph, *old_input);
|
||||
output_args.push_back(&wait_output);
|
||||
|
||||
return wait_output;
|
||||
|
|
@ -257,8 +250,8 @@ Node* AddForwardWait(Graph& graph,
|
|||
}
|
||||
|
||||
Node* wait_node = &CreateBottleneckNode(
|
||||
graph, "WaitEvent", "backward_record", "", nullptr,
|
||||
input_args, output_args);
|
||||
graph, "WaitEvent", "backward_record", "", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
forward_waited_event_name = input_args.front()->Name();
|
||||
output_tensor_name = output_args.front()->Name();
|
||||
|
|
@ -276,13 +269,15 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
|
|||
std::string& backward_waited_event_name,
|
||||
std::string& forward_output_name,
|
||||
std::string& backward_output_name) {
|
||||
if (!forward_send != !backward_recv){
|
||||
ORT_THROW("Graph requires either having both send forward node "
|
||||
"and recv backword node, or none of them. Currently the graph "
|
||||
"has send forward: ", forward_send, " and recv backward: ", backward_recv);
|
||||
if (!forward_send != !backward_recv) {
|
||||
ORT_THROW(
|
||||
"Graph requires either having both send forward node "
|
||||
"and recv backward node, or none of them. Currently the graph "
|
||||
"has send forward: ",
|
||||
forward_send, " and recv backward: ", backward_recv);
|
||||
}
|
||||
|
||||
if (!forward_send && !backward_recv){
|
||||
if (!forward_send && !backward_recv) {
|
||||
// Last partition doesn't have send forward and recv backward. No insert
|
||||
// needed.
|
||||
return Status::OK();
|
||||
|
|
@ -302,14 +297,16 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
|
|||
|
||||
// Add send forward op's output as record op's input and output
|
||||
for (auto& output : forward_send->MutableOutputDefs()) {
|
||||
auto& new_output = CreateNodeArg(graph, output);
|
||||
// send doesn't have optional output, so the node cannot be nullptr.
|
||||
assert(output);
|
||||
auto& new_output = CreateNodeArg(graph, *output);
|
||||
output_args.push_back(&new_output);
|
||||
input_args.push_back(output);
|
||||
}
|
||||
|
||||
record_node = &CreateBottleneckNode(
|
||||
graph, "RecordEvent", "forward_record", "", nullptr,
|
||||
input_args, output_args);
|
||||
graph, "RecordEvent", "forward_record", "", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
forward_recorded_event_name = record_node->InputDefs()[0]->Name();
|
||||
forward_output_name = record_node->OutputDefs()[0]->Name();
|
||||
|
|
@ -327,13 +324,16 @@ Status AddOrSkipForwardRecordBackwardWait(Graph& graph,
|
|||
std::end(record_node->MutableOutputDefs()));
|
||||
|
||||
auto& input = backward_recv->MutableInputDefs()[0];
|
||||
auto& new_output = CreateNodeArg(graph, input);
|
||||
|
||||
// recv node doesn't have optional input, so the node cannot be nullptr.
|
||||
assert(input);
|
||||
auto& new_output = CreateNodeArg(graph, *input);
|
||||
output_args.push_back(&new_output);
|
||||
input = &new_output;
|
||||
|
||||
wait_node = &CreateBottleneckNode(
|
||||
graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
|
||||
input_args, output_args);
|
||||
graph, "WaitEvent", "backward_wait", "Backward pass", nullptr,
|
||||
input_args, output_args);
|
||||
|
||||
backward_waited_event_name = wait_node->InputDefs()[0]->Name();
|
||||
backward_output_name = wait_node->OutputDefs()[0]->Name();
|
||||
|
|
@ -352,8 +352,8 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
|
|||
ORT_ENFORCE(node_args[i]->Name() != new_node_args[i]->Name());
|
||||
ORT_ENFORCE(node_args[i]->Type() == new_node_args[i]->Type());
|
||||
|
||||
for (auto& node: nodes) {
|
||||
for (auto& node_arg: node->MutableInputDefs()) {
|
||||
for (auto& node : nodes) {
|
||||
for (auto& node_arg : node->MutableInputDefs()) {
|
||||
// Only replace when node's input name matches node_args[i].
|
||||
if (node_arg->Name().compare(node_args[i]->Name()) != 0) {
|
||||
continue;
|
||||
|
|
@ -371,11 +371,11 @@ void ReplaceNodeArgs(std::vector<Node*>& nodes,
|
|||
}
|
||||
|
||||
std::string AddEventBeforeNode(
|
||||
Graph& graph,
|
||||
Node* node,
|
||||
const std::string& event_op_type,
|
||||
const std::string& event_op_name,
|
||||
const std::string& event_id_name) {
|
||||
Graph& graph,
|
||||
Node* node,
|
||||
const std::string& event_op_type,
|
||||
const std::string& event_op_name,
|
||||
const std::string& event_id_name) {
|
||||
if (!node) {
|
||||
// No event operator is inserted, so we don't have its input event name.
|
||||
return "";
|
||||
|
|
@ -410,11 +410,11 @@ std::string AddEventBeforeNode(
|
|||
}
|
||||
|
||||
std::string AddEventAfterNode(
|
||||
Graph& graph,
|
||||
Node* node,
|
||||
const std::string& event_op_type,
|
||||
const std::string& event_op_name,
|
||||
const std::string& event_id_name) {
|
||||
Graph& graph,
|
||||
Node* node,
|
||||
const std::string& event_op_type,
|
||||
const std::string& event_op_name,
|
||||
const std::string& event_id_name) {
|
||||
if (!node) {
|
||||
// No event operator is inserted, so we don't have its input event name.
|
||||
return "";
|
||||
|
|
@ -431,7 +431,7 @@ std::string AddEventAfterNode(
|
|||
for (size_t i = 0; i < node_args.size(); ++i) {
|
||||
// Find consumer of "node"'s i-th output.
|
||||
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(
|
||||
node_args[i]->Name());
|
||||
node_args[i]->Name());
|
||||
// Replace node_args[i] with new_node_args[i] in nodes.
|
||||
ReplaceNodeArgs(consumer_nodes, {node_args[i]}, {new_node_args[i]});
|
||||
}
|
||||
|
|
@ -457,9 +457,9 @@ Status AddForwardWaitAfterRecv(
|
|||
std::vector<std::string>& new_input_names,
|
||||
std::string& event_name) {
|
||||
event_name = AddEventAfterNode(
|
||||
graph, comm_node,
|
||||
"WaitEvent", "forward_wait_after_recv",
|
||||
"forward_wait_after_recv_event_id");
|
||||
graph, comm_node,
|
||||
"WaitEvent", "forward_wait_after_recv",
|
||||
"forward_wait_after_recv_event_id");
|
||||
if (event_name.empty()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
|
|
@ -474,9 +474,9 @@ Status AddForwardRecordBeforeSend(
|
|||
std::vector<std::string>& new_input_names,
|
||||
std::string& event_name) {
|
||||
event_name = AddEventBeforeNode(
|
||||
graph, comm_node,
|
||||
"RecordEvent", "forward_record_before_send",
|
||||
"forward_record_before_send_event_id");
|
||||
graph, comm_node,
|
||||
"RecordEvent", "forward_record_before_send",
|
||||
"forward_record_before_send_event_id");
|
||||
if (event_name.empty()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
|
|
@ -491,9 +491,9 @@ Status AddBackwardWaitAfterRecv(
|
|||
std::vector<std::string>& new_input_names,
|
||||
std::string& event_name) {
|
||||
event_name = AddEventAfterNode(
|
||||
graph, comm_node,
|
||||
"WaitEvent", "backward_wait_after_recv",
|
||||
"backward_wait_after_recv_event_id");
|
||||
graph, comm_node,
|
||||
"WaitEvent", "backward_wait_after_recv",
|
||||
"backward_wait_after_recv_event_id");
|
||||
if (event_name.empty()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
|
|
@ -508,9 +508,9 @@ Status AddBackwardRecordBeforeSend(
|
|||
std::vector<std::string>& new_input_names,
|
||||
std::string& event_name) {
|
||||
event_name = AddEventBeforeNode(
|
||||
graph, comm_node,
|
||||
"RecordEvent", "backward_record_before_send",
|
||||
"backward_record_before_send_event_id");
|
||||
graph, comm_node,
|
||||
"RecordEvent", "backward_record_before_send",
|
||||
"backward_record_before_send_event_id");
|
||||
if (event_name.empty()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
|
|
@ -605,20 +605,20 @@ Status SetInputsOutputsAndResolve(Graph& graph,
|
|||
// Record-2: Tell others that backward pass is done.
|
||||
// Record-3: Tell others that backward result has been passed to another stage.
|
||||
Status TransformGraphForPipeline(
|
||||
Graph& graph,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
std::string& forward_waited_event_name,
|
||||
std::string& forward_recorded_event_name,
|
||||
std::string& backward_waited_event_name,
|
||||
std::string& backward_recorded_event_name,
|
||||
std::string& forward_wait_output_name,
|
||||
std::string& forward_record_output_name,
|
||||
std::string& backward_wait_output_name,
|
||||
std::string& backward_record_output_name,
|
||||
std::string& forward_waited_event_after_recv_name,
|
||||
std::string& forward_recorded_event_before_send_name,
|
||||
std::string& backward_waited_event_after_recv_name,
|
||||
std::string& backward_recorded_event_before_send_name) {
|
||||
Graph& graph,
|
||||
const std::unordered_set<std::string>& weights_to_train,
|
||||
std::string& forward_waited_event_name,
|
||||
std::string& forward_recorded_event_name,
|
||||
std::string& backward_waited_event_name,
|
||||
std::string& backward_recorded_event_name,
|
||||
std::string& forward_wait_output_name,
|
||||
std::string& forward_record_output_name,
|
||||
std::string& backward_wait_output_name,
|
||||
std::string& backward_record_output_name,
|
||||
std::string& forward_waited_event_after_recv_name,
|
||||
std::string& forward_recorded_event_before_send_name,
|
||||
std::string& backward_waited_event_after_recv_name,
|
||||
std::string& backward_recorded_event_before_send_name) {
|
||||
// Declare nodes according to their topological order.
|
||||
Node* forward_wait{nullptr};
|
||||
Node* forward_send{nullptr};
|
||||
|
|
@ -650,27 +650,27 @@ Status TransformGraphForPipeline(
|
|||
std::vector<std::string> new_output_names;
|
||||
|
||||
backward_record = AddBackwardRecord(
|
||||
graph,
|
||||
backward_send,
|
||||
new_input_names,
|
||||
new_output_names,
|
||||
backward_recorded_event_name,
|
||||
backward_record_output_name);
|
||||
graph,
|
||||
backward_send,
|
||||
new_input_names,
|
||||
new_output_names,
|
||||
backward_recorded_event_name,
|
||||
backward_record_output_name);
|
||||
forward_wait = AddForwardWait(
|
||||
graph,
|
||||
forward_recv,
|
||||
new_input_names,
|
||||
forward_waited_event_name,
|
||||
forward_wait_output_name);
|
||||
graph,
|
||||
forward_recv,
|
||||
new_input_names,
|
||||
forward_waited_event_name,
|
||||
forward_wait_output_name);
|
||||
ORT_RETURN_IF_ERROR(AddOrSkipForwardRecordBackwardWait(
|
||||
graph,
|
||||
forward_send,
|
||||
backward_recv,
|
||||
new_input_names,
|
||||
forward_recorded_event_name,
|
||||
backward_waited_event_name,
|
||||
forward_record_output_name,
|
||||
backward_wait_output_name));
|
||||
graph,
|
||||
forward_send,
|
||||
backward_recv,
|
||||
new_input_names,
|
||||
forward_recorded_event_name,
|
||||
backward_waited_event_name,
|
||||
forward_record_output_name,
|
||||
backward_wait_output_name));
|
||||
|
||||
// Different stages have different patterns of Send & Recv.
|
||||
// For different patterns, we add different WaitEvent and Record.
|
||||
|
|
@ -694,11 +694,11 @@ Status TransformGraphForPipeline(
|
|||
// One and only one of is_first_stage, is_middle_stage, and is_last_stage can be true.
|
||||
const unsigned int stage_flag_sum = is_first_stage + is_middle_stage + is_last_stage;
|
||||
ORT_RETURN_IF_NOT(stage_flag_sum == 1u,
|
||||
"The processed graph should be classified into a stage, "
|
||||
"but we see more than one true's in the following statements. ",
|
||||
"Is first stage? ", is_first_stage, ". ",
|
||||
"Is middle stage? ", is_middle_stage, ". ",
|
||||
"Is last stage? ", is_last_stage, ".");
|
||||
"The processed graph should be classified into a stage, "
|
||||
"but we see more than one true's in the following statements. ",
|
||||
"Is first stage? ", is_first_stage, ". ",
|
||||
"Is middle stage? ", is_middle_stage, ". ",
|
||||
"Is last stage? ", is_last_stage, ".");
|
||||
|
||||
// Now, we add Wait's in parentheses shown below.
|
||||
// 1. First stage:
|
||||
|
|
@ -713,19 +713,17 @@ Status TransformGraphForPipeline(
|
|||
if (is_first_stage) {
|
||||
// If first stage, insert after forward WaitEvent.
|
||||
ORT_RETURN_IF_ERROR(AddForwardWaitAfterRecv(
|
||||
graph,
|
||||
forward_wait,
|
||||
new_input_names,
|
||||
forward_waited_event_after_recv_name
|
||||
));
|
||||
graph,
|
||||
forward_wait,
|
||||
new_input_names,
|
||||
forward_waited_event_after_recv_name));
|
||||
} else if (is_middle_stage || is_last_stage) {
|
||||
// If middle stage or last stage, insert after forward Recv.
|
||||
ORT_RETURN_IF_ERROR(AddForwardWaitAfterRecv(
|
||||
graph,
|
||||
forward_recv,
|
||||
new_input_names,
|
||||
forward_waited_event_after_recv_name
|
||||
));
|
||||
graph,
|
||||
forward_recv,
|
||||
new_input_names,
|
||||
forward_waited_event_after_recv_name));
|
||||
}
|
||||
|
||||
// Now, we add Record's in parentheses shown below.
|
||||
|
|
@ -740,11 +738,10 @@ Status TransformGraphForPipeline(
|
|||
// ----------------------> BW -> Record -> Send -> Record
|
||||
if (is_first_stage || is_middle_stage) {
|
||||
ORT_RETURN_IF_ERROR(AddForwardRecordBeforeSend(
|
||||
graph,
|
||||
forward_send,
|
||||
new_input_names,
|
||||
forward_recorded_event_before_send_name
|
||||
));
|
||||
graph,
|
||||
forward_send,
|
||||
new_input_names,
|
||||
forward_recorded_event_before_send_name));
|
||||
}
|
||||
|
||||
// Now, we add Wait's in parentheses shown below.
|
||||
|
|
@ -759,11 +756,10 @@ Status TransformGraphForPipeline(
|
|||
// ----------------------> BW -> Record -> Send -> Record
|
||||
if (is_first_stage || is_middle_stage) {
|
||||
ORT_RETURN_IF_ERROR(AddBackwardWaitAfterRecv(
|
||||
graph,
|
||||
backward_recv,
|
||||
new_input_names,
|
||||
backward_waited_event_after_recv_name
|
||||
));
|
||||
graph,
|
||||
backward_recv,
|
||||
new_input_names,
|
||||
backward_waited_event_after_recv_name));
|
||||
}
|
||||
|
||||
// Now, we add Record's in parentheses shown below.
|
||||
|
|
@ -778,18 +774,16 @@ Status TransformGraphForPipeline(
|
|||
// ----------------------> BW -> (Record) -> Send -> Record
|
||||
if (is_first_stage) {
|
||||
ORT_RETURN_IF_ERROR(AddBackwardRecordBeforeSend(
|
||||
graph,
|
||||
backward_record,
|
||||
new_input_names,
|
||||
backward_recorded_event_before_send_name
|
||||
));
|
||||
graph,
|
||||
backward_record,
|
||||
new_input_names,
|
||||
backward_recorded_event_before_send_name));
|
||||
} else if (is_middle_stage || is_last_stage) {
|
||||
ORT_RETURN_IF_ERROR(AddBackwardRecordBeforeSend(
|
||||
graph,
|
||||
backward_send,
|
||||
new_input_names,
|
||||
backward_recorded_event_before_send_name
|
||||
));
|
||||
graph,
|
||||
backward_send,
|
||||
new_input_names,
|
||||
backward_recorded_event_before_send_name));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(SetInputsOutputsAndResolve(graph, weights_to_train, new_input_names, new_output_names));
|
||||
|
|
@ -801,11 +795,11 @@ Status TransformGraphForPipeline(
|
|||
// It also creates an initializer to store its value.
|
||||
template <typename T>
|
||||
void AddNewScalarNodeArgAndInitializer(Graph& graph,
|
||||
const std::string& op_name,
|
||||
onnx::TensorProto_DataType type,
|
||||
T data,
|
||||
std::vector<NodeArg*>& new_node_args,
|
||||
std::vector<std::string>& new_names) {
|
||||
const std::string& op_name,
|
||||
onnx::TensorProto_DataType type,
|
||||
T data,
|
||||
std::vector<NodeArg*>& new_node_args,
|
||||
std::vector<std::string>& new_names) {
|
||||
AddNewNodeArg(graph, op_name, type, new_node_args, new_names);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto proto_data;
|
||||
|
|
@ -825,6 +819,218 @@ void AddNewScalarNodeArgAndInitializer(Graph& graph,
|
|||
graph.AddInitializedTensor(proto_data);
|
||||
}
|
||||
|
||||
// Given a node, this function finds all its connected nodes (consumer nodes and producer nodes) and
|
||||
// connected inputs and outputs in the given graph, then adds them to the containers passed in.
|
||||
Status FindAllConnectedNodes(Graph& graph,
|
||||
Node* node,
|
||||
std::vector<Node*>& connected_nodes,
|
||||
std::set<NodeArg*>& connected_inputs,
|
||||
std::set<NodeArg*>& connected_outputs
|
||||
) {
|
||||
assert(node);
|
||||
ORT_THROW_IF_ERROR(node->ForEachMutableWithIndex(
|
||||
node->MutableInputDefs(),
|
||||
[&](NodeArg& node_arg, size_t /*index*/) {
|
||||
if (graph.IsInputsIncludingInitializers(&node_arg) || graph.IsInitializedTensor(node_arg.Name())) {
|
||||
connected_inputs.insert(&node_arg);
|
||||
} else {
|
||||
Node* producer_node = graph.GetMutableProducerNode(node_arg.Name());
|
||||
if (producer_node == nullptr) {
|
||||
// got nullptr as producer node. This could be because the input is a constant op which will be optimized
|
||||
// away. Print out this information and continue.
|
||||
// TODO: re-visit the different cases to see if there are other situations aside from constant ops.
|
||||
LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node.";
|
||||
} else {
|
||||
connected_nodes.push_back(producer_node);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
ORT_THROW_IF_ERROR(node->ForEachMutableWithIndex(
|
||||
node->MutableOutputDefs(),
|
||||
[&](NodeArg& node_arg, size_t /*index*/) {
|
||||
if (!graph.IsOutput(&node_arg)) {
|
||||
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(node_arg.Name());
|
||||
connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end());
|
||||
|
||||
} else {
|
||||
connected_outputs.insert(&node_arg);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// PipelineStageNodeGroup groups nodes that share the same input initializer and belong to the same stage.
|
||||
// It is used to distinguish other nodes that share the same input initializer but belong to
|
||||
// other pipeline partitions after split.
|
||||
struct PipelineStageNodeGroup {
|
||||
const size_t stage_id;
|
||||
|
||||
// Vector of nodes that have the same initializer input and belong to the same stage. Noted that
|
||||
// the consumer nodes of a particular initializer can be more than one, so we need a vector to store those
|
||||
// nodes.
|
||||
std::vector<Node*> nodes;
|
||||
PipelineStageNodeGroup(const size_t stage, std::vector<Node*>& node_group) : stage_id(stage), nodes(std::move(node_group)){};
|
||||
};
|
||||
|
||||
// This function passes through the given initializer across stages specified in node_groups[i].stage_id.
|
||||
// This applies to the case when initializer is used in multiple stages, say stage a and stage b (a<b). We will
|
||||
// keep the initializer in stage a and pass it down to b through the send nodes and recv nodes.
|
||||
common::Status AddPassthroughInitializer(Graph& graph,
|
||||
NodeArg* initializer,
|
||||
const std::vector<PipelineStageNodeGroup>& node_groups,
|
||||
const std::vector<Node*>& send_nodes,
|
||||
const std::vector<Node*>& recv_nodes) {
|
||||
assert(initializer);
|
||||
ORT_ENFORCE(node_groups.size() >= 2, "Initializer ", initializer->Name(), " is not shared across stages.");
|
||||
|
||||
const size_t from_stage = node_groups.front().stage_id;
|
||||
const size_t to_stage = node_groups.back().stage_id;
|
||||
|
||||
ORT_ENFORCE(from_stage < to_stage, "Pass through from_stage (", from_stage,
|
||||
") is not less than the to_stage (", to_stage, ").");
|
||||
|
||||
auto dtype = initializer->TypeAsProto()->tensor_type().elem_type();
|
||||
|
||||
auto current_node_arg = initializer;
|
||||
|
||||
size_t node_group_index = 1;
|
||||
for (auto i = from_stage; i < to_stage; ++i) {
|
||||
// processing send node in cut i
|
||||
auto& send_attributes = send_nodes[i]->GetMutableAttributes();
|
||||
auto& send_element_types = send_attributes["element_types"];
|
||||
send_element_types.add_ints(static_cast<int64_t>(dtype));
|
||||
send_nodes[i]->MutableInputDefs().push_back(current_node_arg);
|
||||
send_nodes[i]->MutableInputArgsCount().back()++;
|
||||
|
||||
// Create a new node_arg for the recv, as the new node_arg from recv node should possess a different id
|
||||
// than the one in send
|
||||
assert(current_node_arg);
|
||||
current_node_arg = &CreateNodeArg(graph, *current_node_arg);
|
||||
|
||||
// process recv node in cut i
|
||||
auto& recv_attributes = recv_nodes[i]->GetMutableAttributes();
|
||||
auto& recv_element_types = recv_attributes["element_types"];
|
||||
recv_element_types.add_ints(static_cast<int64_t>(dtype));
|
||||
recv_nodes[i]->MutableOutputDefs().push_back(current_node_arg);
|
||||
|
||||
// update the consumer node's input if the node's group is not in the first partition
|
||||
if (i > from_stage && node_groups[node_group_index].stage_id == (i + 1)) {
|
||||
for (auto node : node_groups[node_group_index].nodes) {
|
||||
for (auto& input_node : node->MutableInputDefs()) {
|
||||
if (input_node == initializer) {
|
||||
input_node = current_node_arg;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
node_group_index++;
|
||||
}
|
||||
}
|
||||
|
||||
ORT_ENFORCE(node_group_index == node_groups.size(), "Not all nodes are updated with new initializer.");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Traverse the graph to find out all connected elements in the graph from start_node. The traverse treats the graph as an
|
||||
// undirected graph.
|
||||
void TraverseGraphWithConnectedElement(Graph& graph,
|
||||
Node* start_node,
|
||||
std::set<Node*>& visited_nodes,
|
||||
std::set<NodeArg*>& visited_inputs,
|
||||
std::set<NodeArg*>& visited_outputs) {
|
||||
assert(start_node);
|
||||
visited_nodes.clear();
|
||||
visited_inputs.clear();
|
||||
visited_outputs.clear();
|
||||
|
||||
std::queue<Node*> node_queue;
|
||||
node_queue.push(start_node);
|
||||
|
||||
while (!node_queue.empty()) {
|
||||
auto node = node_queue.front();
|
||||
node_queue.pop();
|
||||
if (visited_nodes.insert(node).second) {
|
||||
std::vector<Node*> connected_nodes;
|
||||
ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs));
|
||||
|
||||
for (auto n : connected_nodes) {
|
||||
ORT_ENFORCE(n != nullptr, "Found nullptr in searching for connected nodes");
|
||||
node_queue.push(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If an initializer is shared across partitions, instead of creating a separate all_reduce op to
|
||||
// sync with those tensors in selected partitions, we save only one copy of that initializer in
|
||||
// the very first partition it appears, and pass that data down to all following partitions
|
||||
// where this initializer is used.
|
||||
common::Status HandleSharedInitializer(Graph& graph,
|
||||
const std::vector<Node*>& send_nodes,
|
||||
const std::vector<Node*>& recv_nodes) {
|
||||
// Map a given initializer to all the partitions that its consumer nodes reside. The size of
|
||||
// the mapped vector reflects how many partitions this initializer's consumer nodes distribute.
|
||||
// If its size is greater than 1, it means this initializer is being used in more than one partition and
|
||||
// we need to proceed those cases.
|
||||
std::map<NodeArg*, std::vector<PipelineStageNodeGroup>> input_consumer_stage_map;
|
||||
|
||||
for (size_t stage = 0; stage <= send_nodes.size(); ++stage) {
|
||||
std::set<Node*> visited_nodes;
|
||||
std::set<NodeArg*> visited_inputs;
|
||||
std::set<NodeArg*> visited_outputs;
|
||||
|
||||
// send_nodes[i] is the Send op in i-th stage's forward pass. recv_nodes[i] is the Recv in the (i+1)-th stage's
|
||||
// forward pass. When not in last stage, traverse start from send node; otherwise start with the recv node as
|
||||
// send node doesn't exist in last partition's forward pass.
|
||||
Node* traverse_start_node = stage < send_nodes.size() ? send_nodes[stage] : recv_nodes.back();
|
||||
TraverseGraphWithConnectedElement(graph,
|
||||
traverse_start_node,
|
||||
visited_nodes,
|
||||
visited_inputs,
|
||||
visited_outputs);
|
||||
|
||||
for (const auto input : visited_inputs) {
|
||||
// If the node is an input instead of an initializer, continue
|
||||
if (!graph.IsInitializedTensor(input->Name())){
|
||||
continue;
|
||||
}
|
||||
|
||||
// group all consumer nodes that shares the same input initializer in visited_consumer_nodes
|
||||
std::vector<Node*> consumer_nodes = graph.GetMutableConsumerNodes(input->Name());
|
||||
std::vector<Node*> visited_consumer_nodes;
|
||||
for(auto consumer_node : consumer_nodes){
|
||||
if (visited_nodes.count(consumer_node) != 0){
|
||||
visited_consumer_nodes.push_back(consumer_node);
|
||||
}
|
||||
}
|
||||
|
||||
if (input_consumer_stage_map.count(input) == 0) {
|
||||
input_consumer_stage_map[input] = std::vector<PipelineStageNodeGroup>{
|
||||
PipelineStageNodeGroup(stage, visited_consumer_nodes)};
|
||||
} else {
|
||||
input_consumer_stage_map[input].push_back({stage, visited_consumer_nodes});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& entry : input_consumer_stage_map) {
|
||||
// If any initializer is shared, handle the logic of passing it from the first seen stage all
|
||||
// the way to last seen stage.
|
||||
if (entry.second.size() > 1) {
|
||||
ORT_RETURN_IF_ERROR(AddPassthroughInitializer(graph,
|
||||
entry.first, // initializer node_arg
|
||||
entry.second, // initializer consumer node groups
|
||||
send_nodes,
|
||||
recv_nodes));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// split the graph into disconnected subgraph based on provided CutInfo
|
||||
common::Status SplitGraph(Graph& graph,
|
||||
std::vector<TrainingSession::TrainingConfiguration::CutInfo> split_edge_groups,
|
||||
|
|
@ -869,30 +1075,30 @@ common::Status SplitGraph(Graph& graph,
|
|||
auto cut_index_str = std::to_string(index);
|
||||
// add input node_arg and initializer for send/recv
|
||||
AddNewScalarNodeArgAndInitializer<bool>(graph,
|
||||
"send_input_signal" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
true, /* initializer data */
|
||||
send_input_args,
|
||||
new_input_names);
|
||||
"send_input_signal" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
true, /* initializer data */
|
||||
send_input_args,
|
||||
new_input_names);
|
||||
AddNewScalarNodeArgAndInitializer<bool>(graph,
|
||||
"recv_input_signal" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
true, /* initializer data */
|
||||
recv_input_args,
|
||||
new_input_names);
|
||||
"recv_input_signal" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
true, /* initializer data */
|
||||
recv_input_args,
|
||||
new_input_names);
|
||||
|
||||
AddNewScalarNodeArgAndInitializer<size_t>(graph,
|
||||
"send_dst_rank" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
index + 1, /* initializer data */
|
||||
send_input_args,
|
||||
new_input_names);
|
||||
"send_dst_rank" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
index + 1, /* initializer data */
|
||||
send_input_args,
|
||||
new_input_names);
|
||||
AddNewScalarNodeArgAndInitializer<size_t>(graph,
|
||||
"recv_src_rank" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
index, /* initializer data */
|
||||
recv_input_args,
|
||||
new_input_names);
|
||||
"recv_src_rank" + cut_index_str,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_INT64,
|
||||
index, /* initializer data */
|
||||
recv_input_args,
|
||||
new_input_names);
|
||||
// add output node_arg for send/recv
|
||||
AddNewNodeArg(graph, "send_output_signal" + cut_index_str, ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
send_output_args, new_output_names);
|
||||
|
|
@ -948,6 +1154,7 @@ common::Status SplitGraph(Graph& graph,
|
|||
if (exiting_updated_node_arg != updated_node_args.end()) {
|
||||
updated_node_arg = exiting_updated_node_arg->second;
|
||||
}
|
||||
assert(updated_node_arg);
|
||||
|
||||
send_input_args.push_back(updated_node_arg);
|
||||
|
||||
|
|
@ -955,7 +1162,7 @@ common::Status SplitGraph(Graph& graph,
|
|||
|
||||
element_types.add_ints(static_cast<int64_t>(dtype));
|
||||
|
||||
auto& new_receive_output = CreateNodeArg(graph, updated_node_arg);
|
||||
auto& new_receive_output = CreateNodeArg(graph, *updated_node_arg);
|
||||
const auto old_shape = *(updated_node_arg->Shape());
|
||||
new_receive_output.SetShape(old_shape);
|
||||
recv_output_args.push_back(&new_receive_output);
|
||||
|
|
@ -973,7 +1180,7 @@ common::Status SplitGraph(Graph& graph,
|
|||
// deal with updating the consumer's input node_args
|
||||
std::vector<Node*> consumer_nodes;
|
||||
if (id.consumer_nodes.has_value()) {
|
||||
for(auto& consumer_node_id : id.consumer_nodes.value()){
|
||||
for (auto& consumer_node_id : id.consumer_nodes.value()) {
|
||||
consumer_nodes.push_back(graph.GetMutableProducerNode(consumer_node_id));
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1019,68 +1226,17 @@ common::Status SplitGraph(Graph& graph,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Read-only variant: collects the producer nodes of `node`'s inputs and the consumer nodes of its outputs
// into connected_nodes, the graph inputs (including initializers) it reads into connected_inputs, and the
// graph outputs it writes into connected_outputs.
Status FindAllConnectedNodes(Graph& graph,
                             const Node* node,
                             std::vector<const Node*>& connected_nodes,
                             std::set<const NodeArg*>& connected_inputs,
                             std::set<const NodeArg*>& connected_outputs) {
  // Classifies one input: graph input/initializer, or the output of a producer node.
  const auto visit_input = [&](const NodeArg& node_arg, size_t /*index*/) {
    if (graph.IsInputsIncludingInitializers(&node_arg)) {
      connected_inputs.insert(&node_arg);
      return Status::OK();
    }

    const Node* producer_node = graph.GetProducerNode(node_arg.Name());
    if (producer_node != nullptr) {
      connected_nodes.push_back(producer_node);
      return Status::OK();
    }

    // got nullptr as producer node. This could be because the input is a constant op which will be optimized
    // away. Print out this information and continue.
    LOGS_DEFAULT(WARNING) << "Cannot find producer node for node_arg: " << node_arg.Name() << ". Skipping this node.";
    return Status::OK();
  };

  // Classifies one output: graph output, or consumed by other nodes.
  const auto visit_output = [&](const NodeArg& node_arg, size_t /*index*/) {
    if (graph.IsOutput(&node_arg)) {
      connected_outputs.insert(&node_arg);
      return Status::OK();
    }

    std::vector<const Node*> consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
    connected_nodes.insert(std::end(connected_nodes), consumer_nodes.begin(), consumer_nodes.end());
    return Status::OK();
  };

  ORT_THROW_IF_ERROR(node->ForEachWithIndex(node->InputDefs(), visit_input));
  ORT_THROW_IF_ERROR(node->ForEachWithIndex(node->OutputDefs(), visit_output));
  return Status::OK();
}
|
||||
|
||||
// traverse the graph from start_node to get the set of nodes contains in this disconnected subgraph
|
||||
common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
|
||||
std::queue<const Node*> node_queue;
|
||||
node_queue.push(start_node);
|
||||
|
||||
std::set<const Node*> visited_nodes;
|
||||
std::set<const NodeArg*> visited_inputs;
|
||||
std::set<const NodeArg*> visited_outputs;
|
||||
common::Status GenerateSubgraph(Graph& graph, Node* start_node) {
|
||||
assert(start_node);
|
||||
std::set<Node*> visited_nodes;
|
||||
std::set<NodeArg*> visited_inputs;
|
||||
std::set<NodeArg*> visited_outputs;
|
||||
|
||||
// BFS graph traverse
|
||||
while (!node_queue.empty()) {
|
||||
auto node = node_queue.front();
|
||||
node_queue.pop();
|
||||
if (visited_nodes.count(node) == 0) {
|
||||
visited_nodes.insert(node);
|
||||
std::vector<const Node*> connected_nodes;
|
||||
ORT_THROW_IF_ERROR(FindAllConnectedNodes(graph, node, connected_nodes, visited_inputs, visited_outputs));
|
||||
TraverseGraphWithConnectedElement(graph, start_node,
|
||||
visited_nodes, visited_inputs, visited_outputs);
|
||||
|
||||
for (auto n : connected_nodes) {
|
||||
ORT_ENFORCE(n!=nullptr, "Found nullptr in searching for connected nodes");
|
||||
node_queue.push(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::set<NodeIndex> visited_node_index;
|
||||
for (auto n : visited_nodes) {
|
||||
visited_node_index.insert(n->Index());
|
||||
|
|
@ -1090,8 +1246,8 @@ common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
|
|||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
// reverse iterate the nodes in topological order, and delete those not visited
|
||||
for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++){
|
||||
if (visited_node_index.count(*it)==0){
|
||||
for (auto it = node_topology_list.rbegin(); it != node_topology_list.rend(); it++) {
|
||||
if (visited_node_index.count(*it) == 0) {
|
||||
graph.RemoveNode(*it);
|
||||
}
|
||||
}
|
||||
|
|
@ -1131,18 +1287,60 @@ Status ApplyPipelinePartitionToMainGraph(
|
|||
std::vector<Node *> send_nodes, recv_nodes;
|
||||
send_nodes.reserve(split_count);
|
||||
recv_nodes.reserve(split_count);
|
||||
|
||||
// Split the graph by cutting edges specified in cut_info. After this function, the graph will be
|
||||
// composed of several disconnected partitions.
|
||||
ORT_RETURN_IF_ERROR(SplitGraph(graph, cut_info, send_nodes, recv_nodes));
|
||||
|
||||
if (send_nodes.size() != split_count || recv_nodes.size() != split_count) {
|
||||
ORT_THROW("Split error: not all cut has Send and Recv inserted. Send node count: ",
|
||||
send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count);
|
||||
send_nodes.size(), ", Recv node count: ", recv_nodes.size(), ", split count: ", split_count);
|
||||
}
|
||||
|
||||
// Check to see if there are any initializers that are being shared between different partitions. If there
|
||||
// is, keep the initializer in the first seen partition and have it pass through by send/recv to the others.
|
||||
ORT_RETURN_IF_ERROR(HandleSharedInitializer(graph, send_nodes, recv_nodes));
|
||||
|
||||
// Now remove the partitions that are not tied to the current pipeline stage and generate the sub-graph.
|
||||
if (pipeline_stage_id < split_count) {
|
||||
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, send_nodes[pipeline_stage_id]));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(GenerateSubgraph(graph, recv_nodes.back()));
|
||||
}
|
||||
|
||||
// Post check to ensure the current partition is correct and matches with Send/Recv nodes inserted during split.
|
||||
Node* send_node{nullptr};
|
||||
Node* recv_node{nullptr};
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Send") {
|
||||
send_node = &node;
|
||||
} else if (node.OpType() == "Recv") {
|
||||
recv_node = &node;
|
||||
}
|
||||
}
|
||||
|
||||
if (pipeline_stage_id == 0){
|
||||
// For the first stage, there should be no recv node, and the send node contained in graph should match the first
|
||||
// send_node inserted during split.
|
||||
ORT_ENFORCE(recv_node == nullptr, "Error: first stage contains Recv node in forward pass.");
|
||||
ORT_ENFORCE(send_node == send_nodes[0],
|
||||
"Error: first stage doesn't contain the right Send node. Possibly CutInfo data is wrong.");
|
||||
}
|
||||
else if (pipeline_stage_id == split_count){
|
||||
// For the last stage, there should be no send node, and the recv node contained in graph should match the last
|
||||
// recv_node inserted during split.
|
||||
ORT_ENFORCE(recv_node == recv_nodes.back(),
|
||||
"Error: last stage doesn't contain the right Recv node. Possibly CutInfo data is wrong.");
|
||||
ORT_ENFORCE(send_node == nullptr, "Error: last stage contains Send node in forward pass.");
|
||||
} else {
|
||||
// For stages in the middle, i-th stage should contain recv node that matches the (i-1)-th inserted recv node, and the i-th
|
||||
// inserted send node.
|
||||
ORT_ENFORCE(recv_node == recv_nodes[pipeline_stage_id - 1],
|
||||
"Error: stage ", pipeline_stage_id, " doesn't contain the right Recv node. Possibly CutInfo data is wrong.");
|
||||
ORT_ENFORCE(send_node == send_nodes[pipeline_stage_id],
|
||||
"Error: stage ", pipeline_stage_id, " doesn't contain the right Send node. Possibly CutInfo data is wrong.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1117,7 +1117,114 @@ void RetrieveSendRecvOperators(
|
|||
}
|
||||
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
|
||||
// Builds "<base_str><index><file_suffix>" as a PathString.
// base_str must be parseable by Path::Parse; this is checked with ORT_ENFORCE.
PathString GenerateFileNameWithIndex(const PathString& base_str, int index, const PathString& file_suffix) {
  Path p;
  ORT_ENFORCE(Path::Parse(base_str, p).IsOK());

  // Bind each intermediate with auto&& so the chain works the same whether the Path helpers hand back a
  // reference or a new object.
  auto&& with_index = p.ConcatIndex(index);
  auto&& with_suffix = with_index.Concat(file_suffix);
  return with_suffix.ToPathString();
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_bert_tiny) {
|
||||
const auto model_path = ORT_TSTR("testdata/bert_toy_optimized.onnx");
|
||||
|
||||
const size_t total_partition_count = 3;
|
||||
TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
|
||||
pipe.do_partition = true;
|
||||
|
||||
// cut model in 3 partitions
|
||||
TrainingSession::TrainingConfiguration::CutInfo cut0 = {
|
||||
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("326"),
|
||||
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("103", {"413", "529"})};
|
||||
|
||||
TrainingSession::TrainingConfiguration::CutInfo cut1 = {
|
||||
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("558"),
|
||||
onnxruntime::training::TrainingSession::TrainingConfiguration::CutEdge("103", {"645"})};
|
||||
|
||||
pipe.cut_list.emplace_back(cut0);
|
||||
pipe.cut_list.emplace_back(cut1);
|
||||
|
||||
TrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mixed_precision_config{};
|
||||
mixed_precision_config.use_fp16_initializers = true;
|
||||
|
||||
// 2 test variations - full precision and mixed precision
|
||||
const std::vector<bool> test_with_fp32{true, false};
|
||||
for (auto is_fp32 : test_with_fp32) {
|
||||
// graph is partitioned into 3 parts.
|
||||
for (int i = 0; i < static_cast<int>(total_partition_count); ++i) {
|
||||
|
||||
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
||||
if (i == static_cast<int>(total_partition_count - 1)) {
|
||||
config.loss_function_config = TrainingSession::TrainingConfiguration::LossFunctionConfiguration{};
|
||||
config.loss_function_config.value().loss_function_info =
|
||||
LossFunctionInfo(OpDef("BertLoss", kOnnxDomain),
|
||||
"total_loss",
|
||||
{/*prediction_masked_lm*/ "prediction_scores",
|
||||
/*prediction_next_sentence*/ "seq_relationship_score",
|
||||
/*masked_lm_positions*/ "masked_lm_positions",
|
||||
/*masked_lm_ids*/ "masked_lm_ids",
|
||||
/*masked_lm_weights*/ "masked_lm_weights",
|
||||
/*next_sentence_labels*/ "next_sentence_labels",
|
||||
/*mlm_loss*/ "mlm_loss",
|
||||
/*nsp_loss*/ "nsp_loss"});
|
||||
}
|
||||
|
||||
// Add weight_names_to_not_train to avoid generating backward graph on those tensor
|
||||
config.weight_names_to_not_train = {
|
||||
"position_01", // Slice's dat input
|
||||
"op_min_ends_expand_10", //op_min_ends_expand_10
|
||||
};
|
||||
|
||||
config.pipeline_config = pipe;
|
||||
config.distributed_config.world_rank = i;
|
||||
config.distributed_config.world_size = total_partition_count;
|
||||
config.distributed_config.local_rank = i;
|
||||
config.distributed_config.local_size = total_partition_count;
|
||||
config.distributed_config.data_parallel_size = 1;
|
||||
config.distributed_config.horizontal_parallel_size = 1;
|
||||
config.distributed_config.pipeline_parallel_size = total_partition_count;
|
||||
config.model_with_training_graph_path = output_file;
|
||||
|
||||
if (!is_fp32) {
|
||||
config.mixed_precision_config = mixed_precision_config;
|
||||
}
|
||||
|
||||
PathString backprop_model_file;
|
||||
Status status = BuildBackPropGraph(model_path, config, backprop_model_file);
|
||||
ASSERT_TRUE(status.IsOK()) << status << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
|
||||
|
||||
// Skip the re-load for mixed-precision case. This model contains grad op that has function body,
|
||||
// which takes a const tensor input. Const cast for input in function body won't be saved in the output
|
||||
// model so reload will run into error.
|
||||
// For the purpose of testing mixed-precision, BuildBackPropGraph above will be sufficient to verify the
|
||||
// partition logic and validate the graph.
|
||||
if (is_fp32) {
|
||||
std::shared_ptr<Model> model;
|
||||
// Ensure the partitioned model load.
|
||||
status = Model::Load(backprop_model_file, model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK()) << status << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
|
||||
|
||||
// verify the first stage contains word embedding as input and the last stage doesn't
|
||||
auto model_proto = model->ToProto();
|
||||
const auto& graph_proto = model_proto.graph();
|
||||
|
||||
bool found_word_embedding = false;
|
||||
for (auto& tensor : graph_proto.initializer()) {
|
||||
if (tensor.name() == "bert.embeddings.word_embeddings.weight") {
|
||||
found_word_embedding = true;
|
||||
}
|
||||
}
|
||||
if (i == 0) {
|
||||
ASSERT_TRUE(found_word_embedding) << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
|
||||
} else {
|
||||
ASSERT_FALSE(found_word_embedding) << " (is_fp32 = " << is_fp32 << ", stage = " << i << ").\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_MLP) {
|
||||
auto model_uri = ORIGINAL_MODEL_PATH;
|
||||
|
||||
TrainingSession::TrainingConfiguration::PipelineConfiguration pipe{};
|
||||
|
|
@ -1137,14 +1244,10 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
|
|||
for(auto is_fp32 : test_with_fp32) {
|
||||
// graph is partitioned into 3 parts.
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
#ifdef _WIN32
|
||||
auto surfix = std::to_wstring(i);
|
||||
#else
|
||||
auto surfix = std::to_string(i);
|
||||
#endif
|
||||
PathString output_file = ORT_TSTR("pipeline_partition_") + surfix + ORT_TSTR("_back.onnx");
|
||||
PathString output_file = GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), i, ORT_TSTR("_back.onnx"));
|
||||
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
||||
config.pipeline_config = pipe;
|
||||
config.distributed_config.world_rank = i;
|
||||
config.distributed_config.world_size = 3;
|
||||
|
|
@ -1178,9 +1281,62 @@ TEST(GradientGraphBuilderTest, PipelineOnlinePartition) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
Runs online pipeline partitioning of ORIGINAL_MODEL_PATH once per pipeline stage and expects
BuildBackPropGraph to fail for every stage.
@param cut_list Cut definitions copied into the pipeline configuration.
@param pipeline_stage_size Number of stages; also used as world/local size and pipeline parallel size.
@param status_check_stages Stages for which the failure is expected as a non-OK Status.
       Every other stage is expected to throw OnnxRuntimeException.
@returns Always Status::OK(); unmet expectations are reported through the gtest EXPECT_* macros.
*/
Status RunOnlinePartition(const std::vector<TrainingSession::TrainingConfiguration::CutInfo>& cut_list,
                          int pipeline_stage_size,
                          std::set<int> status_check_stages = {}) {
  auto model_uri = ORIGINAL_MODEL_PATH;

  // The same partitioning request is shared by every stage.
  TrainingSession::TrainingConfiguration::PipelineConfiguration partition_request{};
  partition_request.cut_list = cut_list;
  partition_request.do_partition = true;

  for (int stage = 0; stage < pipeline_stage_size; ++stage) {
    PathString stage_output =
        GenerateFileNameWithIndex(ORT_TSTR("pipeline_partition_"), stage, ORT_TSTR("_back.onnx"));

    auto config = MakeBasicTrainingConfig();
    config.pipeline_config = partition_request;
    config.model_with_training_graph_path = stage_output;

    // Pure pipeline parallelism: each stage is its own rank.
    auto& dist = config.distributed_config;
    dist.world_rank = stage;
    dist.local_rank = stage;
    dist.world_size = pipeline_stage_size;
    dist.local_size = pipeline_stage_size;
    dist.pipeline_parallel_size = pipeline_stage_size;
    dist.data_parallel_size = 1;
    dist.horizontal_parallel_size = 1;

    PathString backprop_model_file;
    const bool expect_error_status = status_check_stages.find(stage) != status_check_stages.end();
    if (!expect_error_status) {
      // Stages not listed must fail by throwing.
      EXPECT_THROW(BuildBackPropGraph(model_uri, config, backprop_model_file), OnnxRuntimeException);
      continue;
    }

    // Listed stages must fail by returning an error status.
    auto status = BuildBackPropGraph(model_uri, config, backprop_model_file);
    EXPECT_FALSE(status.IsOK());
  }

  return Status::OK();
}
|
||||
|
||||
// Feeds invalid cut configurations to RunOnlinePartition, which checks that
// building the backprop graph fails for every pipeline stage.
TEST(GradientGraphBuilderTest, PipelineOnlinePartition_Invalid_Input) {
  using CutEdge = TrainingSession::TrainingConfiguration::CutEdge;
  using CutInfo = TrainingSession::TrainingConfiguration::CutInfo;

  // Case 1: invalid cut edge ("3").
  CutInfo invalid_cut_edge = {CutEdge("3")};
  const std::vector<CutInfo> invalid_edge_cuts{invalid_cut_edge};
  ASSERT_STATUS_OK(RunOnlinePartition(invalid_edge_cuts, 2 /* pipeline_stage_size */));

  // Case 2: a single cut does not match a stage size of 3.
  CutInfo cut_edge = {CutEdge("T3")};
  const std::vector<CutInfo> mismatched_cuts{cut_edge};
  ASSERT_STATUS_OK(RunOnlinePartition(mismatched_cuts, 3 /* pipeline_stage_size */));

  // Case 3: cut_info list given out of order (cut1 before cut0).
  CutInfo cut0 = {CutEdge("T3")};
  CutInfo cut1 = {CutEdge("T6")};
  const std::vector<CutInfo> unordered_cuts{cut1, cut0};
  ASSERT_STATUS_OK(RunOnlinePartition(unordered_cuts, 3 /* pipeline_stage_size */, {0, 2} /* status_check_stages */));
}
|
||||
|
||||
// verify pipeline config can load and gradient graph can construct.
|
||||
TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
||||
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
|
||||
PathString filename_base = ORT_TSTR("testdata/test_training_model_");
|
||||
|
||||
auto load_and_check_gradient_graph = [](int stageIdx, PathString& input_file, PathString& output_file) {
|
||||
auto config = MakeBasicTrainingConfig();
|
||||
|
|
@ -1286,13 +1442,9 @@ TEST(GradientGraphBuilderTest, TrainingSession_PipelineTransform_base) {
|
|||
};
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
#ifdef _WIN32
|
||||
auto surfix = std::to_wstring(i);
|
||||
#else
|
||||
auto surfix = std::to_string(i);
|
||||
#endif
|
||||
PathString input_file = filename_base + surfix + ORT_TSTR(".onnx");
|
||||
PathString output_file = filename_base + surfix + ORT_TSTR("_back.onnx");
|
||||
PathString input_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR(".onnx"));
|
||||
PathString output_file = GenerateFileNameWithIndex(filename_base, i, ORT_TSTR("_back.onnx"));
|
||||
|
||||
load_and_check_gradient_graph(i, input_file, output_file);
|
||||
}
|
||||
}
|
||||
|
|
@ -1358,12 +1510,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
|||
|
||||
std::vector<PathString> sub_model_files(num_subs);
|
||||
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
|
||||
#ifdef _WIN32
|
||||
auto sub_id_str = std::to_wstring(sub_id);
|
||||
#else
|
||||
auto sub_id_str = std::to_string(sub_id);
|
||||
#endif
|
||||
sub_model_files[sub_id] = ORT_TSTR("sub_") + sub_id_str + ORT_TSTR(".onnx");
|
||||
sub_model_files[sub_id] = GenerateFileNameWithIndex(ORT_TSTR("sub_"), sub_id, ORT_TSTR(".onnx"));
|
||||
}
|
||||
|
||||
PipelineSplitter splitter;
|
||||
|
|
@ -1383,12 +1530,7 @@ TEST(GradientGraphBuilderTest, TrainingSession_WithPipeline) {
|
|||
for (size_t sub_id = 0; sub_id < num_subs; ++sub_id) {
|
||||
auto& sub_sess = subs[sub_id];
|
||||
sub_sess.so.enable_profiling = true;
|
||||
#ifdef _WIN32
|
||||
auto sub_id_str = std::to_wstring(sub_id);
|
||||
#else
|
||||
auto sub_id_str = std::to_string(sub_id);
|
||||
#endif
|
||||
sub_sess.so.profile_file_prefix = ORT_TSTR("pipeline") + sub_id_str;
|
||||
sub_sess.so.profile_file_prefix = GenerateFileNameWithIndex(ORT_TSTR("pipeline"), static_cast<int>(sub_id), ORT_TSTR(""));
|
||||
|
||||
sub_sess.run_options.run_log_verbosity_level = sub_sess.so.session_log_verbosity_level;
|
||||
sub_sess.run_options.run_tag = sub_sess.so.session_logid;
|
||||
|
|
|
|||
Loading…
Reference in a new issue