mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Enable static memory planning for pipeline. (#4204)
* Enable static memory planning for pipeline. 1. We fix a bug when resolving symbolic shape for scalars. 2. We pass the original inputs to all pipeline stages so that the symbolic shapes can be resolved. * Further Improvements 1. Address comments. 2. Further reduce activation size by ~50% when pipeline is on. This is done by removing all but one gradient tensor from the last RecordEvent in the backward pass. * Address a comment * Fix Windows build
This commit is contained in:
parent
b377266eb3
commit
de9da123cf
4 changed files with 113 additions and 14 deletions
|
|
@ -55,6 +55,10 @@ class NodeArg {
|
|||
@returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */
|
||||
const ONNX_NAMESPACE::TensorShapeProto* Shape() const;
|
||||
|
||||
/** Return an indicator.
|
||||
@returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */
|
||||
bool HasTensorOrScalarShape() const;
|
||||
|
||||
/** Sets the shape.
|
||||
@remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called,
|
||||
as the shape information is stored as part of TypeProto. */
|
||||
|
|
|
|||
|
|
@ -190,16 +190,26 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
|
|||
|
||||
#ifdef ENABLE_TRAINING
|
||||
namespace {
|
||||
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
|
||||
Status ResolveDimParams(const GraphViewer& graph,
|
||||
const std::map<std::string, TensorShape>& feeds,
|
||||
std::unordered_map<std::string, int64_t>& out) {
|
||||
for (const auto* input : graph.GetInputs()) {
|
||||
auto* shape = input->Shape();
|
||||
auto it = feeds.find(input->Name());
|
||||
if (it == feeds.end())
|
||||
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() + " is not found in the feed list, unable to resolve the value for dynamic shape.");
|
||||
if (!shape || shape->dim_size() != static_cast<int>(it->second.NumDimensions()))
|
||||
if (it == feeds.end()) {
|
||||
return Status(ONNXRUNTIME, FAIL,
|
||||
"Graph input " + input->Name() +
|
||||
" is not found in the feed list, unable to resolve the value for dynamic shape.");
|
||||
}
|
||||
if (it->second.NumDimensions() == 0 && !shape) {
|
||||
// This is a scalar, which has nothing to do with symbolic shapes.
|
||||
continue;
|
||||
}
|
||||
if (!shape || shape->dim_size() != static_cast<int>(it->second.NumDimensions())) {
|
||||
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() +
|
||||
"'s shape is not present or its shape doesn't match feed's shape."
|
||||
"Unable to resolve the value for dynamic shape");
|
||||
"'s shape is not present or its shape doesn't match feed's shape."
|
||||
"Unable to resolve the value for dynamic shape");
|
||||
}
|
||||
for (int k = 0, end = shape->dim_size(); k < end; ++k) {
|
||||
if (shape->dim()[k].has_dim_param()) {
|
||||
out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]});
|
||||
|
|
@ -320,7 +330,7 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<
|
|||
void SessionState::ResolveMemoryPatternFlag() {
|
||||
if (enable_mem_pattern_) {
|
||||
for (auto* input : graph_viewer_->GetInputs()) {
|
||||
if (!input->Shape()) {
|
||||
if (!input->HasTensorOrScalarShape()) {
|
||||
enable_mem_pattern_ = false;
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,6 +172,26 @@ const TensorShapeProto* NodeArg::Shape() const {
|
|||
}
|
||||
}
|
||||
|
||||
bool NodeArg::HasTensorOrScalarShape() const {
|
||||
const TypeProto* type = TypeAsProto();
|
||||
if (!type) return false;
|
||||
const auto type_case = type->value_case();
|
||||
switch (type_case) {
|
||||
case TypeProto::kTensorType:
|
||||
case TypeProto::kSparseTensorType:
|
||||
// Standard tensor has a valid shape field while
|
||||
// scalar's shape is empty. Thus, we don't need to
|
||||
// check shape here.
|
||||
return true;
|
||||
case TypeProto::kSequenceType:
|
||||
case TypeProto::kMapType:
|
||||
case TypeProto::kOpaqueType:
|
||||
case TypeProto::VALUE_NOT_SET:
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void NodeArg::SetShape(const TensorShapeProto& shape) {
|
||||
const auto type_case = node_arg_info_.type().value_case();
|
||||
switch (type_case) {
|
||||
|
|
|
|||
|
|
@ -41,9 +41,12 @@ void AddNewNodeArg(Graph& graph,
|
|||
new_node_args.push_back(&new_node_arg);
|
||||
}
|
||||
|
||||
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
|
||||
// backward node's input.
|
||||
void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
|
||||
// Gradient graph can contain some dangling leaf nodes. This function collects
|
||||
// their first output using the returned vector.
|
||||
std::vector<NodeArg*> FindBackwardLeafNodes(Graph& graph) {
|
||||
// leaf_node_args[i] is the i-th leaf node's first output in the backward
|
||||
// pass.
|
||||
std::vector<NodeArg*> leaf_node_args;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (!IsBackward(node)) {
|
||||
// only check backward node
|
||||
|
|
@ -59,11 +62,52 @@ void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
|
|||
}
|
||||
}
|
||||
if (!find_consumer_nodes && outputs.size() > 0) {
|
||||
input_args.push_back(outputs[0]);
|
||||
leaf_node_args.push_back(outputs[0]);
|
||||
}
|
||||
}
|
||||
|
||||
return leaf_node_args;
|
||||
};
|
||||
|
||||
// This function converts tensor NodeArg to a boolean scalar so that last
|
||||
// backward RecordEvent doesn't block the early release of large gradient
|
||||
// tensors. If we connect gradient tensors directly to that RecordEvent,
|
||||
// we need a memory block (as large as a whole model) to store gradient
|
||||
// for each trainable tensor until the end of backward pass.
|
||||
//
|
||||
// The newly created boolean scalar may be appended to signal_args. If
|
||||
// signal_args is empty, the source of signal_args[i] would be tensor_args[i].
|
||||
void ConvertTensorToBoolSignal(
|
||||
Graph& graph,
|
||||
const std::vector<NodeArg*>& tensor_args,
|
||||
std::vector<NodeArg*>& signal_args) {
|
||||
|
||||
for (auto tensor_arg: tensor_args) {
|
||||
// Declare the scalar signal this "tensor_arg" will be converted into.
|
||||
auto signal_arg = &CreateTypedNodeArg(
|
||||
graph,
|
||||
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
|
||||
"signal_" + tensor_arg->Name()
|
||||
);
|
||||
|
||||
// Add the new scalar to user-specified vector.
|
||||
signal_args.push_back(signal_arg);
|
||||
|
||||
// Add tensor-to-scalar conversion node.
|
||||
const auto name = graph.GenerateNodeName("tensor_to_scalar_signal");
|
||||
std::vector<NodeArg*> input_args{tensor_arg};
|
||||
std::vector<NodeArg*> output_args{signal_arg};
|
||||
graph.AddNode(
|
||||
name,
|
||||
"Group",
|
||||
"",
|
||||
input_args,
|
||||
output_args,
|
||||
nullptr,
|
||||
kMSDomain);
|
||||
}
|
||||
}
|
||||
|
||||
// Return mirror variables for node_arg with a different name.
|
||||
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
|
||||
const auto& new_name = graph.GenerateNodeArgName(base_arg->Name());
|
||||
|
|
@ -133,7 +177,19 @@ Node* AddBackwardRecord(Graph& graph,
|
|||
std::begin(backward_send->MutableOutputDefs()),
|
||||
std::end(backward_send->MutableOutputDefs()));
|
||||
}
|
||||
FindLeafNodes(graph, input_args);
|
||||
|
||||
// Find all leaf nodes' first outputs. They are used together as control edges
|
||||
// to determine if backward pass is finished.
|
||||
auto backward_leaf_node_args = FindBackwardLeafNodes(graph);
|
||||
|
||||
// For each leaf tensor in the backward pass, we use "Group" operator to
|
||||
// convert it to a boolean scalar so that the original leaf's memory can be
|
||||
// released earlier.
|
||||
|
||||
// TODO: use full list instead of the first element after changing
|
||||
// topological sort to depth-first from inputs.
|
||||
std::vector<NodeArg*> sub_backward_leaf_node_args{backward_leaf_node_args[0]};
|
||||
ConvertTensorToBoolSignal(graph, backward_leaf_node_args, input_args);
|
||||
|
||||
// Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
|
||||
// the added Record backward op will have its first passthrough input as output.
|
||||
|
|
@ -1040,8 +1096,17 @@ common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
|
|||
}
|
||||
}
|
||||
|
||||
// update the graph with only visited inputs and outputs
|
||||
graph.SetInputs({visited_inputs.begin(), visited_inputs.end()});
|
||||
// If the following line is uncommented, middle and last pipeline stages may
|
||||
// have unresolved symbolic shapes. The reason is that some symbolic shapes
|
||||
// are defined for the original inputs, if original inputs are removed, we
|
||||
// lose the hint to resolve symbolic shapes. For example, if an original
|
||||
// input's shape is [batch, sequence, 1024], that input should be provided as
|
||||
// a feed to all pipeline stages. Otherwise, we don't know the actual values
|
||||
// of "batch" and "sequence".
|
||||
//
|
||||
// graph.SetInputs({visited_inputs.begin(), visited_inputs.end()});
|
||||
|
||||
// update the graph with only visited outputs
|
||||
graph.SetOutputs({visited_outputs.begin(), visited_outputs.end()});
|
||||
graph.SetGraphResolveNeeded();
|
||||
graph.SetGraphProtoSyncNeeded();
|
||||
|
|
|
|||
Loading…
Reference in a new issue