Enable static memory planning for pipeline. (#4204)

* Enable static memory planning for pipeline.
1. We fix a bug when resolving symbolic shape for scalars.
2. We pass the original inputs to all pipeline stages so that
   the symbolic shapes can be resolved.

* Further Improvements
1. Address comments.
2. Further reduce activation size by ~50% when pipeline is on.
   This is done by removing all but one gradient tensor from the last
   RecordEvent in the backward pass.

* Address a comment

* Fix Windows build
This commit is contained in:
Wei-Sheng Chin 2020-06-12 21:43:50 -07:00 committed by GitHub
parent b377266eb3
commit de9da123cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 14 deletions

View file

@ -55,6 +55,10 @@ class NodeArg {
@returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */
const ONNX_NAMESPACE::TensorShapeProto* Shape() const;
/** Return an indicator.
@returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */
bool HasTensorOrScalarShape() const;
/** Sets the shape.
@remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called,
as the shape information is stored as part of TypeProto. */

View file

@ -190,16 +190,26 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
#ifdef ENABLE_TRAINING
namespace {
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
Status ResolveDimParams(const GraphViewer& graph,
const std::map<std::string, TensorShape>& feeds,
std::unordered_map<std::string, int64_t>& out) {
for (const auto* input : graph.GetInputs()) {
auto* shape = input->Shape();
auto it = feeds.find(input->Name());
if (it == feeds.end())
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() + " is not found in the feed list, unable to resolve the value for dynamic shape.");
if (!shape || shape->dim_size() != static_cast<int>(it->second.NumDimensions()))
if (it == feeds.end()) {
return Status(ONNXRUNTIME, FAIL,
"Graph input " + input->Name() +
" is not found in the feed list, unable to resolve the value for dynamic shape.");
}
if (it->second.NumDimensions() == 0 && !shape) {
// This is a scalar, which has nothing to do with symbolic shapes.
continue;
}
if (!shape || shape->dim_size() != static_cast<int>(it->second.NumDimensions())) {
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() +
"'s shape is not present or its shape doesn't match feed's shape."
"Unable to resolve the value for dynamic shape");
"'s shape is not present or its shape doesn't match feed's shape."
"Unable to resolve the value for dynamic shape");
}
for (int k = 0, end = shape->dim_size(); k < end; ++k) {
if (shape->dim()[k].has_dim_param()) {
out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]});
@ -320,7 +330,7 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector<
void SessionState::ResolveMemoryPatternFlag() {
if (enable_mem_pattern_) {
for (auto* input : graph_viewer_->GetInputs()) {
if (!input->Shape()) {
if (!input->HasTensorOrScalarShape()) {
enable_mem_pattern_ = false;
break;
}

View file

@ -172,6 +172,26 @@ const TensorShapeProto* NodeArg::Shape() const {
}
}
bool NodeArg::HasTensorOrScalarShape() const {
const TypeProto* type = TypeAsProto();
if (!type) return false;
const auto type_case = type->value_case();
switch (type_case) {
case TypeProto::kTensorType:
case TypeProto::kSparseTensorType:
// Standard tensor has a valid shape field while
// scalar's shape is empty. Thus, we don't need to
// check shape here.
return true;
case TypeProto::kSequenceType:
case TypeProto::kMapType:
case TypeProto::kOpaqueType:
case TypeProto::VALUE_NOT_SET:
default:
return false;
}
}
void NodeArg::SetShape(const TensorShapeProto& shape) {
const auto type_case = node_arg_info_.type().value_case();
switch (type_case) {

View file

@ -41,9 +41,12 @@ void AddNewNodeArg(Graph& graph,
new_node_args.push_back(&new_node_arg);
}
// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent
// backward node's input.
void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
// Gradient graph can contain some dangling leaf nodes. This function collects
// their first output using the returned vector.
std::vector<NodeArg*> FindBackwardLeafNodes(Graph& graph) {
// leaf_node_args[i] is the i-th leaf node's first output in the backward
// pass.
std::vector<NodeArg*> leaf_node_args;
for (auto& node : graph.Nodes()) {
if (!IsBackward(node)) {
// only check backward node
@ -59,11 +62,52 @@ void FindLeafNodes(Graph& graph, std::vector<NodeArg*>& input_args) {
}
}
if (!find_consumer_nodes && outputs.size() > 0) {
input_args.push_back(outputs[0]);
leaf_node_args.push_back(outputs[0]);
}
}
return leaf_node_args;
};
// Creates one boolean "signal" NodeArg per entry of tensor_args and wires
// each pair together with a "Group" node from kMSDomain (tensor in, signal
// out).
//
// Motivation: the last backward RecordEvent should not consume gradient
// tensors directly. If it did, every gradient would have to stay alive until
// the end of the backward pass, which costs a memory block as large as the
// whole model. Consuming the small boolean signals instead lets the large
// gradient tensors be released early.
//
// Each new signal is appended to signal_args in the order of tensor_args.
// Consequently, if signal_args is empty on entry, signal_args[i] is produced
// from tensor_args[i].
void ConvertTensorToBoolSignal(
    Graph& graph,
    const std::vector<NodeArg*>& tensor_args,
    std::vector<NodeArg*>& signal_args) {
  for (NodeArg* source : tensor_args) {
    // Boolean NodeArg that stands in for "source" from now on.
    auto& signal = CreateTypedNodeArg(
        graph,
        ONNX_NAMESPACE::TensorProto_DataType_BOOL,
        "signal_" + source->Name());

    // Hand the new signal back to the caller.
    signal_args.push_back(&signal);

    // Insert the node that turns the tensor into the signal.
    const auto node_name = graph.GenerateNodeName("tensor_to_scalar_signal");
    std::vector<NodeArg*> group_inputs{source};
    std::vector<NodeArg*> group_outputs{&signal};
    graph.AddNode(node_name,
                  "Group",
                  "",
                  group_inputs,
                  group_outputs,
                  nullptr,
                  kMSDomain);
  }
}
// Return mirror variables for node_arg with a different name.
NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) {
const auto& new_name = graph.GenerateNodeArgName(base_arg->Name());
@ -133,7 +177,19 @@ Node* AddBackwardRecord(Graph& graph,
std::begin(backward_send->MutableOutputDefs()),
std::end(backward_send->MutableOutputDefs()));
}
FindLeafNodes(graph, input_args);
// Find all leaf nodes' first outputs. They are used together as control edges
// to determine if backward pass is finished.
auto backward_leaf_node_args = FindBackwardLeafNodes(graph);
// For each leaf tensor in the backward pass, we use "Group" operator to
// convert it to a boolean scalar so that the original leaf's memory can be
// released earlier.
// TODO: use full list instead of the first element after changing
// topological sort to depth-first from inputs.
std::vector<NodeArg*> sub_backward_leaf_node_args{backward_leaf_node_args[0]};
ConvertTensorToBoolSignal(graph, backward_leaf_node_args, input_args);
// Optimizer will be added after applying pipeline transformer. To support partial graph evaluation,
// the added Record backward op will have its first passthrough input as output.
@ -1040,8 +1096,17 @@ common::Status GenerateSubgraph(Graph& graph, const Node* start_node) {
}
}
// update the graph with only visited inputs and outputs
graph.SetInputs({visited_inputs.begin(), visited_inputs.end()});
// If the following line is uncommented, middle and last pipeline stages may
// have unresolved symbolic shapes. The reason is that some symbolic shapes
// are defined for the original inputs; if original inputs are removed, we
// lose the hint to resolve symbolic shapes. For example, if an original
// input's shape is [batch, sequence, 1024], that input should be provided as
// a feed to all pipeline stages. Otherwise, we don't know the actual values
// of "batch" and "sequence".
//
// graph.SetInputs({visited_inputs.begin(), visited_inputs.end()});
// update the graph with only visited outputs
graph.SetOutputs({visited_outputs.begin(), visited_outputs.end()});
graph.SetGraphResolveNeeded();
graph.SetGraphProtoSyncNeeded();