diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 6a71ecd697..461e217b0b 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -55,6 +55,10 @@ class NodeArg { @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */ const ONNX_NAMESPACE::TensorShapeProto* Shape() const; + /** Return an indicator. + @returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */ + bool HasTensorOrScalarShape() const; + /** Sets the shape. @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called, as the shape information is stored as part of TypeProto. */ diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 8f32f8b226..09660295c4 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -190,16 +190,26 @@ static int64_t CalculateMemoryPatternsKey(const std::vector& feeds, std::unordered_map& out) { +Status ResolveDimParams(const GraphViewer& graph, + const std::map& feeds, + std::unordered_map& out) { for (const auto* input : graph.GetInputs()) { auto* shape = input->Shape(); auto it = feeds.find(input->Name()); - if (it == feeds.end()) - return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() + " is not found in the feed list, unable to resolve the value for dynamic shape."); - if (!shape || shape->dim_size() != static_cast(it->second.NumDimensions())) + if (it == feeds.end()) { + return Status(ONNXRUNTIME, FAIL, + "Graph input " + input->Name() + + " is not found in the feed list, unable to resolve the value for dynamic shape."); + } + if (it->second.NumDimensions() == 0 && !shape) { + // This is a scalar, which has nothing to do with symbolic shapes. + continue; + } + if (!shape || shape->dim_size() != static_cast(it->second.NumDimensions())) { return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() + - "'s shape is not present or its shape doesn't match feed's shape." - "Unable to resolve the value for dynamic shape"); + "'s shape is not present or its shape doesn't match feed's shape." + "Unable to resolve the value for dynamic shape"); + } for (int k = 0, end = shape->dim_size(); k < end; ++k) { if (shape->dim()[k].has_dim_param()) { out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]}); @@ -320,7 +330,7 @@ const MemoryPatternGroup* SessionState::GetMemoryPatternGroup(const std::vector< void SessionState::ResolveMemoryPatternFlag() { if (enable_mem_pattern_) { for (auto* input : graph_viewer_->GetInputs()) { - if (!input->Shape()) { + if (!input->HasTensorOrScalarShape()) { enable_mem_pattern_ = false; break; } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 0e11eb1ba3..8d6fc7457a 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -172,6 +172,26 @@ const TensorShapeProto* NodeArg::Shape() const { } } +bool NodeArg::HasTensorOrScalarShape() const { + const TypeProto* type = TypeAsProto(); + if (!type) return false; + const auto type_case = type->value_case(); + switch (type_case) { + case TypeProto::kTensorType: + case TypeProto::kSparseTensorType: + // Standard tensor has a valid shape field while + // scalar's shape is empty. Thus, we don't need to + // check shape here. + return true; + case TypeProto::kSequenceType: + case TypeProto::kMapType: + case TypeProto::kOpaqueType: + case TypeProto::VALUE_NOT_SET: + default: + return false; + } +} + void NodeArg::SetShape(const TensorShapeProto& shape) { const auto type_case = node_arg_info_.type().value_case(); switch (type_case) { diff --git a/orttraining/orttraining/core/graph/pipeline_transformer.cc b/orttraining/orttraining/core/graph/pipeline_transformer.cc index 6fde902670..2f6762be49 100644 --- a/orttraining/orttraining/core/graph/pipeline_transformer.cc +++ b/orttraining/orttraining/core/graph/pipeline_transformer.cc @@ -41,9 +41,12 @@ void AddNewNodeArg(Graph& graph, new_node_args.push_back(&new_node_arg); } -// gradient graph can contain some dangling leaf nodes. Add them all to WaitEvent -// backward node's input. -void FindLeafNodes(Graph& graph, std::vector& input_args) { +// Gradient graph can contain some dangling leaf nodes. This function collects +// their first output using the returned vector. +std::vector FindBackwardLeafNodes(Graph& graph) { + // leaf_node_args[i] is the i-th leaf node's first output in the backward + // pass. + std::vector leaf_node_args; for (auto& node : graph.Nodes()) { if (!IsBackward(node)) { // only check backward node @@ -59,11 +62,52 @@ void FindLeafNodes(Graph& graph, std::vector& input_args) { } } if (!find_consumer_nodes && outputs.size() > 0) { - input_args.push_back(outputs[0]); + leaf_node_args.push_back(outputs[0]); } } + + return leaf_node_args; }; +// This function converts tensor NodeArg to a boolean scalar so that last +// backward RecordEvent doesn't block the early release of large gradient +// tensors. If we connect gradient tensors directly to that RecordEvent, +// we need a memory block (as large as a whole model) to store gradient +// for each trainable tensor until the end of backward pass. +// +// The newly created boolean scalar may be appended to signal_args. If +// signal_args is empty, the source of signal_args[i] would be tensor_args[i]. +void ConvertTensorToBoolSignal( + Graph& graph, + const std::vector& tensor_args, + std::vector& signal_args) { + + for (auto tensor_arg: tensor_args) { + // Declare the scalar signal this "tensor_arg" will be converted into. + auto signal_arg = &CreateTypedNodeArg( + graph, + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + "signal_" + tensor_arg->Name() + ); + + // Add the new scalar to user-specified vector. + signal_args.push_back(signal_arg); + + // Add tensor-to-scalar conversion node. + const auto name = graph.GenerateNodeName("tensor_to_scalar_signal"); + std::vector input_args{tensor_arg}; + std::vector output_args{signal_arg}; + graph.AddNode( + name, + "Group", + "", + input_args, + output_args, + nullptr, + kMSDomain); + } +} + // Return mirror variables for node_arg with a different name. NodeArg& CreateNodeArg(Graph& graph, const NodeArg* base_arg) { const auto& new_name = graph.GenerateNodeArgName(base_arg->Name()); @@ -133,7 +177,19 @@ Node* AddBackwardRecord(Graph& graph, std::begin(backward_send->MutableOutputDefs()), std::end(backward_send->MutableOutputDefs())); } - FindLeafNodes(graph, input_args); + + // Find all leaf nodes' frist inputs. They are used togehter as control edges + // to determine if backward pass is finished. + auto backward_leaf_node_args = FindBackwardLeafNodes(graph); + + // For each leaf tensor in the backward pass, we use "Group" operator to + // convert it to a boolean scalar so that the original leaf's memory can be + // released earlier. + + // TODO: use full list instead of the first element after changining + // topological sort to depth-first from inputs. + std::vector sub_backward_leaf_node_args{backward_leaf_node_args[0]}; + ConvertTensorToBoolSignal(graph, backward_leaf_node_args, input_args); // Optimizer will be added after applying pipeline transformer. To support partial graph evaluation, // the added Record backward op will have its first passthrough input as output. @@ -1040,8 +1096,17 @@ common::Status GenerateSubgraph(Graph& graph, const Node* start_node) { } } - // update the grah with only visited inputs and outputs - graph.SetInputs({visited_inputs.begin(), visited_inputs.end()}); + // If the following line is uncommented, middle and last pipeline stages may + // have unresolved symbolic shapes. The reason is that some symbolic shapes + // are defined for the original inputs, if original inputs are removed, we + // loss the hit to resolve symbolic shapes. For example, if an original + // input's shape is [batch, sequence, 1024], that input should be provided as + // a feed to all pipeline stages. Otherwise, we don't know the actual values + // of "batch" and "sequence". + // + // graph.SetInputs({visited_inputs.begin(), visited_inputs.end()}); + + // update the grah with only visited outputs graph.SetOutputs({visited_outputs.begin(), visited_outputs.end()}); graph.SetGraphResolveNeeded(); graph.SetGraphProtoSyncNeeded();