diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index caa736b000..f71c2ace69 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -95,12 +95,20 @@ class Node { /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */ const std::string& Domain() const noexcept { return domain_; } + /** Gets the Node's exection priority. + @remarks Lower value means higher priority */ + int Priority() const noexcept { return priority_; }; + + /** Sets the execution priority of a node. + @remarks Lower value means higher priority */ + void SetPriority(int priority) noexcept; + /** Gets the node description. */ const std::string& Description() const noexcept { return description_; } /** Gets the Node's Node::Type. */ Node::Type NodeType() const noexcept { return node_type_; } - + /** Gets the opset version that the Node's operator was first defined in. @returns Opset version. If -1 the Node's operator has not been set. @remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build @@ -507,6 +515,9 @@ class Node { const ONNX_NAMESPACE::OpSchema* op_ = nullptr; #endif + // Execution priority, lower value for higher priority + int priority_ = 0; + // set from op_->SinceVersion() or via deserialization when OpSchema is not available int since_version_ = -1; @@ -850,6 +861,13 @@ class Graph { const std::function& comp, const std::function& stop) const; + /** Performs topological sort with Kahn's algorithm on the graph/s. + @param enter Visit function that will be invoked on a node when it is visited. + @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic. + */ + void KahnsTopologicalSort(const std::function& enter, + const std::function& comp) const; + /** Gets the map of operator domains to their opset versions. */ const std::unordered_map& DomainToVersionMap() const noexcept { return domain_to_version_; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index f68afda130..77cba8d3d5 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -4,6 +4,7 @@ #pragma once #include "core/graph/graph.h" +#include "core/framework/session_options.h" namespace onnxruntime { class Function; @@ -87,7 +88,7 @@ class GraphViewer { int MaxNodeIndex() const noexcept; /** Gets the NodeIndex values for the Graph nodes, sorted into topological order. */ - const std::vector& GetNodesInTopologicalOrder() const; + const std::vector& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const; /** Gets the NodeIndex values for the root nodes in the Graph. @@ -144,6 +145,10 @@ class GraphViewer { // The NodeIndex values of the graph nodes sorted in topological order. std::vector nodes_in_topological_order_; + + // The NodeIndex values of the graph nodes sorted in topological order with priority. + std::vector nodes_in_topological_order_with_priority_; + // Graph root nodes. std::vector root_nodes_; }; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 4e79618da8..711c6c28c0 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -768,7 +768,7 @@ class PlannerImpl { }; // namespace onnxruntime Status PlannerImpl::CreatePlan() { - auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(); + auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(context_.GetExecutionOrder()); int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; diff --git a/onnxruntime/core/framework/allocation_planner.h b/onnxruntime/core/framework/allocation_planner.h index 4130bf7481..3cfe12b930 100644 --- a/onnxruntime/core/framework/allocation_planner.h +++ b/onnxruntime/core/framework/allocation_planner.h @@ -28,22 +28,28 @@ class ISequentialPlannerContext { // If it returns true, planner won't reuse output tensors // see PlannerImpl::ComputeReusePlan virtual bool IsParallelExecutionEnabled() const { return false; } + + virtual ExecutionOrder GetExecutionOrder() const { return ExecutionOrder::DEFAULT; } }; class SequentialPlannerContext : public ISequentialPlannerContext { public: - SequentialPlannerContext(ExecutionMode execution_mode) - : m_execution_mode(execution_mode) { + SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order) + : execution_mode_(execution_mode), + exection_order_(execution_order) { } const ONNX_NAMESPACE::TensorShapeProto* GetShape(const onnxruntime::NodeArg& arg) const override { return arg.Shape(); } - bool IsParallelExecutionEnabled() const override { return m_execution_mode == ExecutionMode::ORT_PARALLEL; } + bool IsParallelExecutionEnabled() const override { return execution_mode_ == ExecutionMode::ORT_PARALLEL; } + + ExecutionOrder GetExecutionOrder() const override { return exection_order_; } private: - ExecutionMode m_execution_mode = ExecutionMode::ORT_SEQUENTIAL; + ExecutionMode execution_mode_ = ExecutionMode::ORT_SEQUENTIAL; + ExecutionOrder exection_order_ = ExecutionOrder::DEFAULT; }; class SequentialPlanner { diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index d8136087c5..8fb05eea0b 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -11,12 +11,25 @@ namespace onnxruntime { +enum class ExecutionOrder { + DEFAULT = 0, // default topological sort + PRIORITY_BASED = 1 // priority-based topological sort +}; + enum class FreeDimensionOverrideType { Invalid = 0, Denotation = 1, Name = 2 }; +enum class ExecutionPriority : int { + GLOBAL_HIGHT = -100, + LOCAL_HIGH = -10, + DEFAULT = 0, + LOCAL_LOW = 10, + GLOBAL_LOW = 100 +}; + struct FreeDimensionOverride { std::string dim_identifier; FreeDimensionOverrideType dim_identifer_type; @@ -29,6 +42,9 @@ struct FreeDimensionOverride { struct SessionOptions { ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL; + // set the execution order of the graph + ExecutionOrder execution_order = ExecutionOrder::DEFAULT; + // enable profiling for this session. bool enable_profiling = false; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 02436626a5..edb54925c8 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -832,7 +832,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string #include #include +#include #include "gsl/gsl" #include "core/common/logging/logging.h" @@ -420,6 +421,10 @@ const Node* Node::NodeConstIterator::operator->() const { return &(operator*()); } +void Node::SetPriority(int priority) noexcept { + priority_ = priority; +} + #if !defined(ORT_MINIMAL_BUILD) void Node::SetNodeType(Node::Type node_type) noexcept { @@ -677,6 +682,7 @@ void Node::Init(const std::string& name, definitions_.input_defs = input_args; definitions_.output_defs = output_args; domain_ = domain; + priority_ = 0; if (kOnnxDomainAlias == domain_) { domain_ = kOnnxDomain; } @@ -1560,6 +1566,44 @@ void Graph::ReverseDFSFrom(const std::vector& from, } } +void Graph::KahnsTopologicalSort(const std::function& enter, + const std::function& comp) const { + std::unordered_map in_degree; + std::priority_queue, decltype(comp)> to_visit(comp); + std::vector topo_order; + + for (auto& node : Nodes()) { + size_t input_edge_count = node.GetInputEdgesCount(); + in_degree.insert({node.Index(), input_edge_count}); + if (input_edge_count == 0) { + to_visit.push(&node); + } + } + + while (!to_visit.empty()) { + const Node* current = to_visit.top(); + to_visit.pop(); + + if (!current) continue; + + if (enter) { + enter(current); + } + + for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { + in_degree[node_it->Index()]--; + + if (in_degree[node_it->Index()] == 0) { + to_visit.push(&*node_it); + } + } + topo_order.push_back(current->Index()); + } + + if (NumberOfNodes() != static_cast(topo_order.size())) { + ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); + } +} #if !defined(ORT_MINIMAL_BUILD) GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...) diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index aa29caef9b..fc022149d5 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,15 +14,45 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { return n1->Index() < n2->Index(); } +struct PriorityNodeCompare { + inline bool IsHighPri(const Node* n) const { + static const std::unordered_set high_pri_ops = {"Shape", "Size"}; + return high_pri_ops.find(n->OpType()) != high_pri_ops.end(); + } + + // Used for std::priority_queue + // If return false, n1 will be output first + // If return true, n2 will be output first + bool operator()(const Node* n1, const Node* n2) const { + // nodes in global high priorty list will be output first + if (IsHighPri(n1) != IsHighPri(n2)) { + return IsHighPri(n2); + } + + // nodes with lower priority value will be output first + if (n1->Priority() != n2->Priority()) { + return n1->Priority() > n2->Priority(); + } + + // otherwise, nodes with lower index will be output first + return n1->Index() > n2->Index(); + } +}; + GraphViewer::GraphViewer(const Graph& graph) { graph_ = &graph; std::vector leaf_nodes; for (auto& node : graph_->Nodes()) { + // This is a leaf node (without any output node) if (node.OutputNodesBegin() == node.OutputNodesEnd()) { - // This is a leaf node (without any output node). leaf_nodes.push_back(&node); } + // This is a root node (without any input node) + if (node.InputEdgesBegin() == node.InputEdgesEnd()) { + root_nodes_.push_back(node.Index()); + } } + graph.ReverseDFSFrom( leaf_nodes, nullptr, @@ -31,11 +61,11 @@ GraphViewer::GraphViewer(const Graph& graph) { }, NodeCompare()); - for (auto& node : graph_->Nodes()) { - if (node.InputEdgesBegin() == node.InputEdgesEnd()) { - root_nodes_.push_back(node.Index()); - } - } + graph.KahnsTopologicalSort( + [this](const Node* n) { + nodes_in_topological_order_with_priority_.push_back(n->Index()); + }, + PriorityNodeCompare()); } // Graph name. @@ -92,8 +122,15 @@ int GraphViewer::MaxNodeIndex() const noexcept { return graph_->MaxNodeIndex(); } -const std::vector& GraphViewer::GetNodesInTopologicalOrder() const { - return nodes_in_topological_order_; +const std::vector& GraphViewer::GetNodesInTopologicalOrder(ExecutionOrder order) const { + switch (order) { + case ExecutionOrder::DEFAULT: + return nodes_in_topological_order_; + case ExecutionOrder::PRIORITY_BASED: + return nodes_in_topological_order_with_priority_; + default: + ORT_THROW("Invalide ExecutionOrder"); + } } const std::vector& GraphViewer::GetRootNodes() const { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ae902322e5..52720e0bbf 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1154,676 +1154,676 @@ common::Status InferenceSession::Initialize() { } #endif // !defined(ORT_MINIMAL_BUILD) - session_state_->ResolveMemoryPatternFlag(); - is_inited_ = true; + session_state_->ResolveMemoryPatternFlag(); + is_inited_ = true; - // we don't directly use the ORT format bytes currently, so free those now - std::vector().swap(ort_format_model_bytes_); + // we don't directly use the ORT format bytes currently, so free those now + std::vector().swap(ort_format_model_bytes_); - // and log telemetry - bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); - env.GetTelemetryProvider().LogSessionCreation( - session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), - model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(), model_->MetaData(), - telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs); - LOGS(*session_logger_, INFO) << "Session successfully initialized."; - } - ORT_CATCH(const NotImplementedException& ex) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what()); + // and log telemetry + bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + env.GetTelemetryProvider().LogSessionCreation( + session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), + model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(), model_->MetaData(), + telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs); + LOGS(*session_logger_, INFO) << "Session successfully initialized."; + } + ORT_CATCH(const NotImplementedException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what()); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + }); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Exception during initialization: ", ex.what()); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + }); + } + ORT_CATCH(...) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Encountered unknown exception in Initialize()"); LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - }); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Exception during initialization: ", ex.what()); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - }); - } - ORT_CATCH(...) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Encountered unknown exception in Initialize()"); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + } + + if (session_profiler_.IsEnabled()) { + session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp); + } + + if (status.IsOK()) { + for (auto& xp : execution_providers_) { + auto end_status = xp->OnSessionInitializationEnd(); + if (status.IsOK()) { + status = end_status; + } + } + } + + return status; } - if (session_profiler_.IsEnabled()) { - session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp); - } - - if (status.IsOK()) { - for (auto& xp : execution_providers_) { - auto end_status = xp->OnSessionInitializationEnd(); - if (status.IsOK()) { - status = end_status; + // This method should be called from within Initialize() only and before the creation of the session state. + // This ensures all providers have been registered in the session and the session state is consistent with the providers. + void InferenceSession::UpdateProvidersWithSharedAllocators() { + using namespace std; + const auto& provider_ids = execution_providers_.GetIds(); + for (const auto& one_shared_alloc : environment_.GetRegisteredSharedAllocators()) { + for (const auto& id : provider_ids) { + auto* provider_ptr = execution_providers_.Get(id); + provider_ptr->ReplaceAllocator(one_shared_alloc); } } } - return status; -} - -// This method should be called from within Initialize() only and before the creation of the session state. -// This ensures all providers have been registered in the session and the session state is consistent with the providers. -void InferenceSession::UpdateProvidersWithSharedAllocators() { - using namespace std; - const auto& provider_ids = execution_providers_.GetIds(); - for (const auto& one_shared_alloc : environment_.GetRegisteredSharedAllocators()) { - for (const auto& id : provider_ids) { - auto* provider_ptr = execution_providers_.Get(id); - provider_ptr->ReplaceAllocator(one_shared_alloc); - } - } -} - -int InferenceSession::GetCurrentNumRuns() const { - return current_num_runs_.load(); -} - -const std::vector& InferenceSession::GetRegisteredProviderTypes() const { - return execution_providers_.GetIds(); -} - -const ProviderOptionsMap& InferenceSession::GetAllProviderOptions() const { - return execution_providers_.GetAllProviderOptions(); -} - -const SessionOptions& InferenceSession::GetSessionOptions() const { - return session_options_; -} - -const DataTransferManager& InferenceSession::GetDataTransferManager() const { - return data_transfer_mgr_; -} - -common::Status InferenceSession::CheckShapes(const std::string& input_name, const TensorShape& input_shape, - const TensorShape& expected_shape) const { - auto input_shape_sz = input_shape.NumDimensions(); - auto expected_shape_sz = expected_shape.NumDimensions(); - if (input_shape_sz != expected_shape_sz) { - std::ostringstream ostr; - ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz - << " Please fix either the inputs or the model."; - return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); + int InferenceSession::GetCurrentNumRuns() const { + return current_num_runs_.load(); } - std::vector invalid_dim_indices; - for (size_t i = 0; i < input_shape_sz; ++i) { - if (expected_shape[i] < 0) { - continue; // this represents a symbolic shape dimension - } - if (input_shape[i] != expected_shape[i]) { - invalid_dim_indices.push_back(i); - } + const std::vector& InferenceSession::GetRegisteredProviderTypes() const { + return execution_providers_.GetIds(); } - if (!invalid_dim_indices.empty()) { - std::ostringstream ostr; - ostr << "Got invalid dimensions for input: " << input_name << " for the following indices\n"; - for (size_t i = 0, end = invalid_dim_indices.size(); i < end; ++i) { - size_t idx = invalid_dim_indices[i]; - ostr << " index: " << idx << " Got: " << input_shape[idx] << " Expected: " << expected_shape[idx] << "\n"; - } - ostr << " Please fix either the inputs or the model."; - return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); + const ProviderOptionsMap& InferenceSession::GetAllProviderOptions() const { + return execution_providers_.GetAllProviderOptions(); } - return Status::OK(); -} + const SessionOptions& InferenceSession::GetSessionOptions() const { + return session_options_; + } + + const DataTransferManager& InferenceSession::GetDataTransferManager() const { + return data_transfer_mgr_; + } + + common::Status InferenceSession::CheckShapes(const std::string& input_name, const TensorShape& input_shape, + const TensorShape& expected_shape) const { + auto input_shape_sz = input_shape.NumDimensions(); + auto expected_shape_sz = expected_shape.NumDimensions(); + if (input_shape_sz != expected_shape_sz) { + std::ostringstream ostr; + ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz + << " Please fix either the inputs or the model."; + return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); + } + + std::vector invalid_dim_indices; + for (size_t i = 0; i < input_shape_sz; ++i) { + if (expected_shape[i] < 0) { + continue; // this represents a symbolic shape dimension + } + if (input_shape[i] != expected_shape[i]) { + invalid_dim_indices.push_back(i); + } + } + + if (!invalid_dim_indices.empty()) { + std::ostringstream ostr; + ostr << "Got invalid dimensions for input: " << input_name << " for the following indices\n"; + for (size_t i = 0, end = invalid_dim_indices.size(); i < end; ++i) { + size_t idx = invalid_dim_indices[i]; + ostr << " index: " << idx << " Got: " << input_shape[idx] << " Expected: " << expected_shape[idx] << "\n"; + } + ostr << " Please fix either the inputs or the model."; + return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); + } -static common::Status CheckTypes(MLDataType actual, MLDataType expected) { - if (actual == expected) { return Status::OK(); } + + static common::Status CheckTypes(MLDataType actual, MLDataType expected) { + if (actual == expected) { + return Status::OK(); + } #ifdef ORT_NO_RTTI - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unexpected input data type"); + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unexpected input data type"); #else auto actual_name = std::string(typeid(*actual).name()); auto expected_name = std::string(typeid(*expected).name()); return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unexpected input data type. Actual: (" + actual_name + ") , expected: (" + expected_name + ")"); #endif -} - -common::Status InferenceSession::ValidateInputs(const std::vector& feed_names, - const std::vector& feeds) const { - if (feed_names.size() != feeds.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(), - "elements, but feeds has ", feeds.size(), " elements."); } - for (size_t i = 0; i < feeds.size(); ++i) { - const auto& feed_name = feed_names[i]; - - auto iter = input_def_map_.find(feed_name); - if (input_def_map_.end() == iter) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name); + common::Status InferenceSession::ValidateInputs(const std::vector& feed_names, + const std::vector& feeds) const { + if (feed_names.size() != feeds.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(), + "elements, but feeds has ", feeds.size(), " elements."); } - auto expected_type = iter->second.ml_data_type; - auto& input_ml_value = feeds.at(i); - if (input_ml_value.IsTensor()) { - // check for type - if (!expected_type->IsTensorType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, - " is not expected to be of type tensor."); - } - auto expected_element_type = expected_type->AsTensorType()->GetElementType(); - auto input_element_type = input_ml_value.Get().DataType(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); + for (size_t i = 0; i < feeds.size(); ++i) { + const auto& feed_name = feed_names[i]; - // check for shape - const auto& expected_shape = iter->second.tensor_shape; - if (expected_shape.NumDimensions() > 0) { - const auto& input_shape = input_ml_value.Get().Shape(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape)); + auto iter = input_def_map_.find(feed_name); + if (input_def_map_.end() == iter) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name); } - } else if (input_ml_value.IsSparseTensor()) { + + auto expected_type = iter->second.ml_data_type; + auto& input_ml_value = feeds.at(i); + if (input_ml_value.IsTensor()) { + // check for type + if (!expected_type->IsTensorType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, + " is not expected to be of type tensor."); + } + auto expected_element_type = expected_type->AsTensorType()->GetElementType(); + auto input_element_type = input_ml_value.Get().DataType(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); + + // check for shape + const auto& expected_shape = iter->second.tensor_shape; + if (expected_shape.NumDimensions() > 0) { + const auto& input_shape = input_ml_value.Get().Shape(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape)); + } + } else if (input_ml_value.IsSparseTensor()) { #if !defined(ORT_MINIMAL_BUILD) - if (!expected_type->IsSparseTensorType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, - " is not expected to be of type sparse tensor."); - } - auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType(); - auto input_element_type = input_ml_value.Get().Values().DataType(); - // TODO: In the future, when sparsetensors are in use, find out how to properly verify the shape - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); + if (!expected_type->IsSparseTensorType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, + " is not expected to be of type sparse tensor."); + } + auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType(); + auto input_element_type = input_ml_value.Get().Values().DataType(); + // TODO: In the future, when sparsetensors are in use, find out how to properly verify the shape + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); #else return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name ", feed_name, " is a sparse tensor, which is not supported in this build."); #endif - } else if (input_ml_value.IsTensorSequence()) { - if (!expected_type->IsTensorSequenceType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, - " is not expected to be of type tensor sequence."); + } else if (input_ml_value.IsTensorSequence()) { + if (!expected_type->IsTensorSequenceType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, + " is not expected to be of type tensor sequence."); + } + auto expected_element_type = expected_type->AsSequenceTensorBase()->GetElementType(); + auto input_element_type = input_ml_value.Get().DataType(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); + } else { + auto input_type = input_ml_value.Type(); + ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_type, expected_type)); } - auto expected_element_type = expected_type->AsSequenceTensorBase()->GetElementType(); - auto input_element_type = input_ml_value.Get().DataType(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type)); - } else { - auto input_type = input_ml_value.Type(); - ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_type, expected_type)); } + + return Status::OK(); } - return Status::OK(); -} - -common::Status InferenceSession::ValidateOutputs(const std::vector& output_names, - const std::vector* p_fetches) const { - if (p_fetches == nullptr) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL"); - } - - if (output_names.empty()) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested."); - } - - if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) { - std::ostringstream ostr; - ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size() - << "p_fetches->size(): " << p_fetches->size(); - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); - } - - for (const auto& name : output_names) { - if (model_output_names_.find(name) == model_output_names_.end()) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name); + common::Status InferenceSession::ValidateOutputs(const std::vector& output_names, + const std::vector* p_fetches) const { + if (p_fetches == nullptr) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL"); } + + if (output_names.empty()) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested."); + } + + if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) { + std::ostringstream ostr; + ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size() + << "p_fetches->size(): " << p_fetches->size(); + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); + } + + for (const auto& name : output_names) { + if (model_output_names_.find(name) == model_output_names_.end()) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name); + } + } + + // TODO add more validation here like checking shape of the allocated buffers + + return common::Status::OK(); } - // TODO add more validation here like checking shape of the allocated buffers - - return common::Status::OK(); -} - -Status InferenceSession::Run(const RunOptions& run_options, - const std::vector& feed_names, const std::vector& feeds, - const std::vector& output_names, std::vector* p_fetches, - const std::vector* p_fetches_device_info) { - TimePoint tp; - if (session_profiler_.IsEnabled()) { - tp = session_profiler_.StartTime(); - } + Status InferenceSession::Run(const RunOptions& run_options, + const std::vector& feed_names, const std::vector& feeds, + const std::vector& output_names, std::vector* p_fetches, + const std::vector* p_fetches_device_info) { + TimePoint tp; + if (session_profiler_.IsEnabled()) { + tp = session_profiler_.StartTime(); + } #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT - TraceLoggingActivity ortrun_activity; - ortrun_activity.SetRelatedActivity(session_activity); - TraceLoggingWriteStart(ortrun_activity, "OrtRun"); + TraceLoggingActivity ortrun_activity; + ortrun_activity.SetRelatedActivity(session_activity); + TraceLoggingWriteStart(ortrun_activity, "OrtRun"); #endif - Status retval = Status::OK(); - const Env& env = Env::Default(); + Status retval = Status::OK(); + const Env& env = Env::Default(); - std::vector exec_providers_to_stop; - exec_providers_to_stop.reserve(execution_providers_.NumProviders()); + std::vector exec_providers_to_stop; + exec_providers_to_stop.reserve(execution_providers_.NumProviders()); - ORT_TRY { - if (!is_inited_) { - LOGS(*session_logger_, ERROR) << "Session was not initialized"; - return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); - } - - // log evaluation start to trace logging provider - env.GetTelemetryProvider().LogEvaluationStart(); - - ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds)); - ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches)); - - FeedsFetchesInfo info(feed_names, output_names, session_state_->GetOrtValueNameIdxMap()); - FeedsFetchesManager feeds_fetches_manager{std::move(info)}; - - if (p_fetches_device_info) { - // populate the target device info. ignored if pre-allocated fetches are provided - const auto& fetch_device_info = *p_fetches_device_info; - auto& fetch_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo(); - - for (size_t i = 0, end = output_names.size(); i < end; ++i) { - fetch_info[i].target_device = fetch_device_info[i]; + ORT_TRY { + if (!is_inited_) { + LOGS(*session_logger_, ERROR) << "Session was not initialized"; + return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); } - } - if (!run_options.run_tag.empty()) { - LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag; - } + // log evaluation start to trace logging provider + env.GetTelemetryProvider().LogEvaluationStart(); - ++current_num_runs_; + ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds)); + ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches)); - // TODO should we add this exec to the list of executors? i guess its not needed now? + FeedsFetchesInfo info(feed_names, output_names, session_state_->GetOrtValueNameIdxMap()); + FeedsFetchesManager feeds_fetches_manager{std::move(info)}; - // scope of owned_run_logger is just the call to Execute. - // If Execute ever becomes async we need a different approach - std::unique_ptr owned_run_logger; - auto run_logger = CreateLoggerForRun(run_options, owned_run_logger); + if (p_fetches_device_info) { + // populate the target device info. ignored if pre-allocated fetches are provided + const auto& fetch_device_info = *p_fetches_device_info; + auto& fetch_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo(); - // info all execution providers InferenceSession:Run started - // TODO: only call OnRunStart for all providers in-use - for (auto& xp : execution_providers_) { - // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); - if (status.IsOK()) - exec_providers_to_stop.push_back(xp.get()); + for (size_t i = 0, end = output_names.size(); i < end; ++i) { + fetch_info[i].target_device = fetch_device_info[i]; + } + } - return status; - }; + if (!run_options.run_tag.empty()) { + LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag; + } - ORT_CHECK_AND_SET_RETVAL(start_func()); - } + ++current_num_runs_; + + // TODO should we add this exec to the list of executors? i guess its not needed now? + + // scope of owned_run_logger is just the call to Execute. + // If Execute ever becomes async we need a different approach + std::unique_ptr owned_run_logger; + auto run_logger = CreateLoggerForRun(run_options, owned_run_logger); + + // info all execution providers InferenceSession:Run started + // TODO: only call OnRunStart for all providers in-use + for (auto& xp : execution_providers_) { + // call OnRunStart and add to exec_providers_to_stop if successful + auto start_func = [&xp, &exec_providers_to_stop]() { + auto status = xp->OnRunStart(); + if (status.IsOK()) + exec_providers_to_stop.push_back(xp.get()); + + return status; + }; + + ORT_CHECK_AND_SET_RETVAL(start_func()); + } #if !defined(ORT_MINIMAL_BUILD) - if (run_options.only_execute_path_to_fetches) { - session_state_->UpdateToBeExecutedNodes(feeds_fetches_manager.GetFeedsFetchesInfo().fetches_mlvalue_idxs); - } + if (run_options.only_execute_path_to_fetches) { + session_state_->UpdateToBeExecutedNodes(feeds_fetches_manager.GetFeedsFetchesInfo().fetches_mlvalue_idxs); + } #endif - // execute the graph - ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches, - session_options_.execution_mode, run_options.terminate, run_logger, - run_options.only_execute_path_to_fetches)); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - retval = Status(common::ONNXRUNTIME, common::FAIL, e.what()); - }); - } - ORT_CATCH(...) { - retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()"); - } + // execute the graph + ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches, + session_options_.execution_mode, run_options.terminate, run_logger, + run_options.only_execute_path_to_fetches)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + retval = Status(common::ONNXRUNTIME, common::FAIL, e.what()); + }); + } + ORT_CATCH(...) { + retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()"); + } - // info all execution providers InferenceSession:Run ended - for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(); - ORT_CHECK_AND_SET_RETVAL(status); - } + // info all execution providers InferenceSession:Run ended + for (auto* xp : exec_providers_to_stop) { + auto status = xp->OnRunEnd(); + ORT_CHECK_AND_SET_RETVAL(status); + } - --current_num_runs_; + --current_num_runs_; - // keep track of telemetry - ++telemetry_.total_runs_since_last_; - telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp); + // keep track of telemetry + ++telemetry_.total_runs_since_last_; + telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp); - // time to send telemetry? - if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > telemetry_.kDurationBetweenSending) { - // send the telemetry - env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_, - telemetry_.total_run_duration_since_last_); - // reset counters - telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now(); - telemetry_.total_runs_since_last_ = 0; - telemetry_.total_run_duration_since_last_ = 0; - } + // time to send telemetry? + if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > telemetry_.kDurationBetweenSending) { + // send the telemetry + env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_, + telemetry_.total_run_duration_since_last_); + // reset counters + telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now(); + telemetry_.total_runs_since_last_ = 0; + telemetry_.total_run_duration_since_last_ = 0; + } - // log evaluation stop to trace logging provider - env.GetTelemetryProvider().LogEvaluationStop(); + // log evaluation stop to trace logging provider + env.GetTelemetryProvider().LogEvaluationStop(); - // send out profiling events (optional) - if (session_profiler_.IsEnabled()) { - session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp); - } + // send out profiling events (optional) + if (session_profiler_.IsEnabled()) { + session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp); + } #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT - TraceLoggingWriteStop(ortrun_activity, "OrtRun"); + TraceLoggingWriteStop(ortrun_activity, "OrtRun"); #endif - return retval; -} - -common::Status InferenceSession::Run(const NameMLValMap& feeds, const std::vector& output_names, - std::vector* p_fetches) { - return Run(RunOptions(), feeds, output_names, p_fetches); -} - -common::Status InferenceSession::Run(const RunOptions& run_options, const NameMLValMap& feeds_map, - const std::vector& output_names, std::vector* p_fetches) { - std::vector feed_names; - std::vector feeds; - - auto num_feeds = feeds_map.size(); - feed_names.reserve(num_feeds); - feeds.reserve(num_feeds); - - for (auto& pair : feeds_map) { - feed_names.push_back(pair.first); - feeds.push_back(pair.second); + return retval; } - return Run(run_options, feed_names, feeds, output_names, p_fetches, nullptr); -} + common::Status InferenceSession::Run(const NameMLValMap& feeds, const std::vector& output_names, + std::vector* p_fetches) { + return Run(RunOptions(), feeds, output_names, p_fetches); + } -std::pair InferenceSession::GetModelMetadata() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + common::Status InferenceSession::Run(const RunOptions& run_options, const NameMLValMap& feeds_map, + const std::vector& output_names, std::vector* p_fetches) { + std::vector feed_names; + std::vector feeds; + + auto num_feeds = feeds_map.size(); + feed_names.reserve(num_feeds); + feeds.reserve(num_feeds); + + for (auto& pair : feeds_map) { + feed_names.push_back(pair.first); + feeds.push_back(pair.second); } + + return Run(run_options, feed_names, feeds, output_names, p_fetches, nullptr); } - return std::make_pair(common::Status::OK(), &model_metadata_); -} - -std::pair InferenceSession::GetModelInputs() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + std::pair InferenceSession::GetModelMetadata() const { + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + } } + + return std::make_pair(common::Status::OK(), &model_metadata_); } - // return required inputs (excludes any inputs used for overriding initializers) - return std::make_pair(common::Status::OK(), &model_->MainGraph().GetInputs()); -} - -std::pair InferenceSession::GetOverridableInitializers() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + std::pair InferenceSession::GetModelInputs() const { + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + } } + + // return required inputs (excludes any inputs used for overriding initializers) + return std::make_pair(common::Status::OK(), &model_->MainGraph().GetInputs()); } - // returns a list of initializers that can be overriden. - return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOverridableInitializers()); -} - -std::pair InferenceSession::GetModelOutputs() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + std::pair InferenceSession::GetOverridableInitializers() const { + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + } } + + // returns a list of initializers that can be overriden. + return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOverridableInitializers()); } - return std::make_pair(common::Status::OK(), &output_def_list_); -} - -common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { - { - std::lock_guard l(session_mutex_); - if (!is_inited_) { - LOGS(*session_logger_, ERROR) << "Session was not initialized"; - return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); + std::pair InferenceSession::GetModelOutputs() const { + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); + } } + + return std::make_pair(common::Status::OK(), &output_def_list_); } - // private constructor, can't use make_unique - *io_binding = std::unique_ptr(new IOBinding(*session_state_)); - return Status::OK(); -} + common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { + { + std::lock_guard l(session_mutex_); + if (!is_inited_) { + LOGS(*session_logger_, ERROR) << "Session was not initialized"; + return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); + } + } -common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& io_binding) { - // TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it? - // io_binding.SynchronizeInputs(); - return Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(), - &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo()); -} + // private constructor, can't use make_unique + *io_binding = std::unique_ptr(new IOBinding(*session_state_)); + return Status::OK(); + } -common::Status InferenceSession::Run(IOBinding& io_binding) { - RunOptions run_options; - return Run(run_options, io_binding); -} + common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& io_binding) { + // TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it? + // io_binding.SynchronizeInputs(); + return Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(), + &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo()); + } -template -void InferenceSession::StartProfiling(const std::basic_string& file_prefix) { - std::basic_ostringstream ss; - ss << file_prefix << "_" << GetCurrentTimeString() << ".json"; - session_profiler_.StartProfiling(ss.str()); -} + common::Status InferenceSession::Run(IOBinding& io_binding) { + RunOptions run_options; + return Run(run_options, io_binding); + } -void InferenceSession::StartProfiling(const std::string& file_prefix) { - StartProfiling(file_prefix); -} + template + void InferenceSession::StartProfiling(const std::basic_string& file_prefix) { + std::basic_ostringstream ss; + ss << file_prefix << "_" << GetCurrentTimeString() << ".json"; + session_profiler_.StartProfiling(ss.str()); + } + + void InferenceSession::StartProfiling(const std::string& file_prefix) { + StartProfiling(file_prefix); + } #ifdef _WIN32 -void InferenceSession::StartProfiling(const std::wstring& file_prefix) { - StartProfiling(file_prefix); -} + void InferenceSession::StartProfiling(const std::wstring& file_prefix) { + StartProfiling(file_prefix); + } #endif -void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) { - session_profiler_.StartProfiling(logger_ptr); -} - -std::string InferenceSession::EndProfiling() { - if (is_model_loaded_) { - if (session_profiler_.IsEnabled()) { - return session_profiler_.EndProfiling(); - } else { - LOGS(*session_logger_, VERBOSE) << "Profiler is disabled."; - return std::string(); - } + void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) { + session_profiler_.StartProfiling(logger_ptr); } - LOGS(*session_logger_, ERROR) << "Could not write a profile because no model was loaded."; - return std::string(); -} -const profiling::Profiler& InferenceSession::GetProfiling() const { - return session_profiler_; -} + std::string InferenceSession::EndProfiling() { + if (is_model_loaded_) { + if (session_profiler_.IsEnabled()) { + return session_profiler_.EndProfiling(); + } else { + LOGS(*session_logger_, VERBOSE) << "Profiler is disabled."; + return std::string(); + } + } + LOGS(*session_logger_, ERROR) << "Could not write a profile because no model was loaded."; + return std::string(); + } -AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { - return session_state_->GetAllocator(mem_info); -} + const profiling::Profiler& InferenceSession::GetProfiling() const { + return session_profiler_; + } + + AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { + return session_state_->GetAllocator(mem_info); + } #if !defined(ORT_MINIMAL_BUILD) -// assumes model has already been loaded before -common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) { - // TODO add other post load processing here - common::Status status = SaveModelMetadata(model); - return status; -} + // assumes model has already been loaded before + common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) { + // TODO add other post load processing here + common::Status status = SaveModelMetadata(model); + return status; + } #endif -common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& model) { - VLOGS(*session_logger_, 1) << "Saving model metadata"; - const onnxruntime::Graph& graph = model.MainGraph(); + common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& model) { + VLOGS(*session_logger_, 1) << "Saving model metadata"; + const onnxruntime::Graph& graph = model.MainGraph(); - // save model metadata - model_metadata_.producer_name = model.ProducerName(); - model_metadata_.description = model.DocString(); - model_metadata_.domain = model.Domain(); - model_metadata_.version = model.ModelVersion(); - model_metadata_.custom_metadata_map = model.MetaData(); - model_metadata_.graph_name = graph.Name(); + // save model metadata + model_metadata_.producer_name = model.ProducerName(); + model_metadata_.description = model.DocString(); + model_metadata_.domain = model.Domain(); + model_metadata_.version = model.ModelVersion(); + model_metadata_.custom_metadata_map = model.MetaData(); + model_metadata_.graph_name = graph.Name(); - required_inputs_.clear(); - for (auto input : graph.GetInputs()) { - required_inputs_.insert(input->Name()); - } - - auto add_inputs = [this](const InputDefList& inputs) { - input_def_map_.clear(); - input_def_map_.reserve(inputs.size()); - for (auto elem : inputs) { - auto elem_type = utils::GetMLDataType(*elem); - auto elem_shape_proto = elem->Shape(); - input_def_map_.insert( - {elem->Name(), - InputDefMetaData( - elem, elem_type, - elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())}); - } - }; - - if (graph.CanOverrideInitializer()) { - // for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists the - // initializer is explicitly overridable. - add_inputs(graph.GetInputsIncludingInitializers()); - } else { - // for IR < 4 we don't allow overriding initializers so that they can be treated as constant. exclude them from - // the list of valid inputs by just using the GetInputs() list. - add_inputs(graph.GetInputs()); - } - - // save outputs - const auto& outputs = graph.GetOutputs(); - output_def_list_ = outputs; // A direct copy of outputs - - model_output_names_.clear(); - model_output_names_.reserve(outputs.size()); - for (const auto& elem : outputs) { - model_output_names_.insert(elem->Name()); - } - - VLOGS(*session_logger_, 1) << "Done saving model metadata"; - return common::Status::OK(); -} - -// Create a Logger for a single execution if possible. Otherwise use the default logger. -// If a new logger is created, it will also be stored in new_run_logger, -// which must remain valid for the duration of the execution. -// If the default logger is used, new_run_logger will remain empty. -// The returned value should be used in the execution. -const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& run_options, - std::unique_ptr& new_run_logger) { - const logging::Logger* run_logger; - - // create a per-run logger if we can - if (logging_manager_ != nullptr) { - std::string run_log_id{session_options_.session_logid}; - - if (!session_options_.session_logid.empty() && !run_options.run_tag.empty()) { - run_log_id += ":"; + required_inputs_.clear(); + for (auto input : graph.GetInputs()) { + required_inputs_.insert(input->Name()); } - run_log_id += run_options.run_tag; + auto add_inputs = [this](const InputDefList& inputs) { + input_def_map_.clear(); + input_def_map_.reserve(inputs.size()); + for (auto elem : inputs) { + auto elem_type = utils::GetMLDataType(*elem); + auto elem_shape_proto = elem->Shape(); + input_def_map_.insert( + {elem->Name(), + InputDefMetaData( + elem, elem_type, + elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())}); + } + }; - logging::Severity severity = logging::Severity::kWARNING; - if (run_options.run_log_severity_level == -1) { - severity = session_logger_->GetSeverity(); + if (graph.CanOverrideInitializer()) { + // for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists the + // initializer is explicitly overridable. + add_inputs(graph.GetInputsIncludingInitializers()); } else { - ORT_ENFORCE(run_options.run_log_severity_level >= 0 && - run_options.run_log_severity_level <= static_cast(logging::Severity::kFATAL), - "Invalid run log severity level. Not a valid onnxruntime::logging::Severity value: ", - run_options.run_log_severity_level); - severity = static_cast(run_options.run_log_severity_level); + // for IR < 4 we don't allow overriding initializers so that they can be treated as constant. exclude them from + // the list of valid inputs by just using the GetInputs() list. + add_inputs(graph.GetInputs()); } - new_run_logger = logging_manager_->CreateLogger(run_log_id, severity, false, run_options.run_log_verbosity_level); + // save outputs + const auto& outputs = graph.GetOutputs(); + output_def_list_ = outputs; // A direct copy of outputs - run_logger = new_run_logger.get(); - VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id; - } else { - // fallback to using default logger. this does NOT have any session or run specific id/tag in it - run_logger = session_logger_; - VLOGS(*run_logger, 1) << "Using default logger for run " << run_options.run_tag; + model_output_names_.clear(); + model_output_names_.reserve(outputs.size()); + for (const auto& elem : outputs) { + model_output_names_.insert(elem->Name()); + } + + VLOGS(*session_logger_, 1) << "Done saving model metadata"; + return common::Status::OK(); } - return *run_logger; -} + // Create a Logger for a single execution if possible. Otherwise use the default logger. + // If a new logger is created, it will also be stored in new_run_logger, + // which must remain valid for the duration of the execution. + // If the default logger is used, new_run_logger will remain empty. + // The returned value should be used in the execution. + const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& run_options, + std::unique_ptr& new_run_logger) { + const logging::Logger* run_logger; -void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { - // create logger for session, using provided logging manager if possible - if (logging_manager != nullptr) { - logging::Severity severity = logging::Severity::kWARNING; - if (session_options_.session_log_severity_level == -1) { - severity = logging::LoggingManager::DefaultLogger().GetSeverity(); + // create a per-run logger if we can + if (logging_manager_ != nullptr) { + std::string run_log_id{session_options_.session_logid}; + + if (!session_options_.session_logid.empty() && !run_options.run_tag.empty()) { + run_log_id += ":"; + } + + run_log_id += run_options.run_tag; + + logging::Severity severity = logging::Severity::kWARNING; + if (run_options.run_log_severity_level == -1) { + severity = session_logger_->GetSeverity(); + } else { + ORT_ENFORCE(run_options.run_log_severity_level >= 0 && + run_options.run_log_severity_level <= static_cast(logging::Severity::kFATAL), + "Invalid run log severity level. Not a valid onnxruntime::logging::Severity value: ", + run_options.run_log_severity_level); + severity = static_cast(run_options.run_log_severity_level); + } + + new_run_logger = logging_manager_->CreateLogger(run_log_id, severity, false, run_options.run_log_verbosity_level); + + run_logger = new_run_logger.get(); + VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id; } else { - ORT_ENFORCE(session_options_.session_log_severity_level >= 0 && - session_options_.session_log_severity_level <= static_cast(logging::Severity::kFATAL), - "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ", - session_options_.session_log_severity_level); - severity = static_cast(session_options_.session_log_severity_level); + // fallback to using default logger. this does NOT have any session or run specific id/tag in it + run_logger = session_logger_; + VLOGS(*run_logger, 1) << "Using default logger for run " << run_options.run_tag; } - owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false, - session_options_.session_log_verbosity_level); - session_logger_ = owned_session_logger_.get(); - } else { - session_logger_ = &logging::LoggingManager::DefaultLogger(); + return *run_logger; + } + + void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { + // create logger for session, using provided logging manager if possible + if (logging_manager != nullptr) { + logging::Severity severity = logging::Severity::kWARNING; + if (session_options_.session_log_severity_level == -1) { + severity = logging::LoggingManager::DefaultLogger().GetSeverity(); + } else { + ORT_ENFORCE(session_options_.session_log_severity_level >= 0 && + session_options_.session_log_severity_level <= static_cast(logging::Severity::kFATAL), + "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ", + session_options_.session_log_severity_level); + severity = static_cast(session_options_.session_log_severity_level); + } + + owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false, + session_options_.session_log_verbosity_level); + session_logger_ = owned_session_logger_.get(); + } else { + session_logger_ = &logging::LoggingManager::DefaultLogger(); + } } -} #if !defined(ORT_MINIMAL_BUILD) -// Registers all the predefined transformers with transformer manager -void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - const std::vector& custom_list) { - auto add_transformers = [&](TransformerLevel level) { - // Generate and register transformers for level - auto transformers_to_register = - optimizer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, - *execution_providers_.Get(onnxruntime::kCpuExecutionProvider), - custom_list); - for (auto& entry : transformers_to_register) { - transformer_manager.Register(std::move(entry), level); - } - }; + // Registers all the predefined transformers with transformer manager + void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const std::vector& custom_list) { + auto add_transformers = [&](TransformerLevel level) { + // Generate and register transformers for level + auto transformers_to_register = + optimizer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, + *execution_providers_.Get(onnxruntime::kCpuExecutionProvider), + custom_list); + for (auto& entry : transformers_to_register) { + transformer_manager.Register(std::move(entry), level); + } + }; - ORT_ENFORCE(graph_optimization_level <= TransformerLevel::MaxLevel, - "Exceeded max transformer level. Current level is set to " + - std::to_string(static_cast(graph_optimization_level))); + ORT_ENFORCE(graph_optimization_level <= TransformerLevel::MaxLevel, + "Exceeded max transformer level. Current level is set to " + + std::to_string(static_cast(graph_optimization_level))); - for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { - TransformerLevel level = static_cast(i); - if ((graph_optimization_level >= level) || !custom_list.empty()) { - add_transformers(level); + for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { + TransformerLevel level = static_cast(i); + if ((graph_optimization_level >= level) || !custom_list.empty()) { + add_transformers(level); + } } } -} #endif // !defined(ORT_MINIMAL_BUILD) -common::Status InferenceSession::WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) { - if (timeout_in_ms > 0) { - ORT_NOT_IMPLEMENTED(__FUNCTION__, "timeout_in_ms >0 is not supported"); // TODO + common::Status InferenceSession::WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) { + if (timeout_in_ms > 0) { + ORT_NOT_IMPLEMENTED(__FUNCTION__, "timeout_in_ms >0 is not supported"); // TODO + } + p_executor_done->Wait(); + + return Status::OK(); } - p_executor_done->Wait(); - return Status::OK(); -} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { + ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); + } -SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { - ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); -} + InferenceSession* SessionIOBinding::GetInferenceSession() { + return sess_; + } -InferenceSession* SessionIOBinding::GetInferenceSession() { - return sess_; -} - -IOBinding* SessionIOBinding::Get() { - return binding_.get(); -} + IOBinding* SessionIOBinding::Get() { + return binding_.get(); + } } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e00adfc65e..9b55c8683a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1089,6 +1089,8 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc") .def_readwrite("execution_mode", &PySessionOptions::execution_mode, R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc") + .def_readwrite("execution_order", &SessionOptions::execution_order, + R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc") .def_property( "graph_optimization_level", [](const PySessionOptions* options) -> GraphOptimizationLevel { diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index f91a8e5e32..307c0b0f49 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -231,12 +231,12 @@ TEST_F(GraphTest, SimpleUnique) { std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(std::move(m), model, nullptr, *logger_)); } - + TEST_F(GraphTest, UnusedValueInfoSerializes) { ModelProto m; m.set_ir_version(4); ImportOpset(m, "", 11); - GraphProto& g = *m.mutable_graph(); + GraphProto& g = *m.mutable_graph(); NodeProto* node = g.add_node(); *node->add_input() = "x"; *node->add_output() = "sum"; @@ -633,9 +633,6 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) { // node_5 (Merge) // | - std::unordered_map, std::vector>> - expected_node_name_to_input_output_args; - TypeProto tensor_int32; tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); @@ -655,29 +652,24 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) { inputs.push_back(&input_arg1); outputs.push_back(&output_arg1); - expected_node_name_to_input_output_args["node_1"] = {inputs, outputs}; graph.AddNode("node_1", "Identity_Fake", "node 1", inputs, outputs); inputs[0] = &input_arg2; outputs[0] = &output_arg2; - expected_node_name_to_input_output_args["node_2"] = {inputs, outputs}; graph.AddNode("node_2", "Identity_Fake", "node 2", inputs, outputs); inputs[0] = &output_arg2; outputs[0] = &output_arg3; - expected_node_name_to_input_output_args["node_3"] = {inputs, outputs}; graph.AddNode("node_3", "Identity_Fake", "node 3", inputs, outputs); inputs[0] = &output_arg1; outputs[0] = &output_arg4; - expected_node_name_to_input_output_args["node_4"] = {inputs, outputs}; graph.AddNode("node_4", "Identity_Fake", "node 4", inputs, outputs); inputs.resize(2); inputs[0] = &output_arg4; inputs[1] = &output_arg3; outputs[0] = &output_arg5; - expected_node_name_to_input_output_args["node_5"] = {inputs, outputs}; graph.AddNode("node_5", "Merge_Fake", "node 3", inputs, outputs); auto status = graph.Resolve(); @@ -700,6 +692,223 @@ TEST_F(GraphTest, GraphConstruction_CheckInputNodeOrderMaintained) { } } +TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_CompressDecompress) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ + node_1 (Identity) compress (pri = LOCAL_HIGH) + | | + node_4 (Identity) decompress (pri = LOCAL_LOW) + \ / + node_5 (Merge) + | + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32); + auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32); + auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out_1", &tensor_int32); + auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32); + auto& output_arg5 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + + auto& compress_node = graph.AddNode("compress", "Identity_Fake", "compress node", {&output_arg0}, {&output_arg2}); + compress_node.SetPriority(static_cast(ExecutionPriority::LOCAL_HIGH)); + + auto& decompress_node = graph.AddNode("decompress", "Identity_Fake", "decompress node", {&output_arg2}, {&output_arg3}); + decompress_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + + graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4}); + graph.AddNode("node_5", "Merge_Fake", "node 3", {&output_arg4, &output_arg3}, {&output_arg5}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // PRIORITY_BASED order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const std::vector expected_priority_based_order = + {"node_0", "compress", "node_1", "node_4", "decompress", "node_5"}; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "Priority based execution order is wrong."; + } + } + + // TOPOLOGICAL order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); + const std::vector expected_topological_order = { + "node_0", "node_1", "node_4", "compress", "decompress", "node_5"}; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_topological_order[i]) << "Priority based execution order is wrong."; + } + } +} + +TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_Recompute) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ + node_1 (Identity) recompute_node_1 (pri = LOCAL_LOW) + | | + node_4 (Identity) | + \ / + node_1_grad (Merge) + | + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32); + auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32); + auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32); + auto& output_arg5 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + + auto& recompute_node = graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2}); + recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + + graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4}); + graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg4, &output_arg2}, {&output_arg5}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // PRIORITY_BASED order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const std::vector expected_priority_based_order = + {"node_0", "node_1", "node_4", "recompute_node_1", "node_1_grad"}; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "Priority based execution order is wrong."; + } + } +} + +TEST_F(GraphTest, GraphConstruction_PriorityBasedTopologicalSort_MultiLayerRecompute) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ + node_1 (Identity) \ + | \ \ + node_2 (Identity) \ \ + | \ \ \ + node_3 (Identity) \ \ \ + | \ \ \ \ + loss (Identity) \ \ \ \ + | | \ \ \ + 1 | | \ \ + \ / | \ | + loss_grad recom_node_3 | | + \ / | | + node_3_grad recom_node_2 | + \ / | + node_2_grad recom_node_1 + \ / + node_1_grad + | + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + // FW graph + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in", &tensor_int32); + auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out", &tensor_int32); + auto& output_arg3 = graph.GetOrCreateNodeArg("node_3_out", &tensor_int32); + auto& output_loss = graph.GetOrCreateNodeArg("loss_out", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + graph.AddNode("node_2", "Identity_Fake", "node 2", {&output_arg1}, {&output_arg2}); + graph.AddNode("node_3", "Identity_Fake", "node 3", {&output_arg2}, {&output_arg3}); + graph.AddNode("loss", "Identity_Fake", "loss node", {&output_arg3}, {&output_loss}); + + // Recompute graph + auto& recomputed_arg3 = graph.GetOrCreateNodeArg("node_3_out_recomputed", &tensor_int32); + auto& recomputed_arg2 = graph.GetOrCreateNodeArg("node_2_out_recomputed", &tensor_int32); + auto& recomputed_arg1 = graph.GetOrCreateNodeArg("node_1_out_recomputed", &tensor_int32); + + auto& recompute_node3 = graph.AddNode("node_3_recompute", "Identity_Fake", "node 3 recompute", {&output_arg2}, {&recomputed_arg3}); + auto& recompute_node2 = graph.AddNode("node_2_recompute", "Identity_Fake", "node 2 recompute", {&output_arg1}, {&recomputed_arg2}); + auto& recompute_node1 = graph.AddNode("node_1_recompute", "Identity_Fake", "node 1 recompute", {&output_arg0}, {&recomputed_arg1}); + recompute_node3.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + recompute_node2.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + recompute_node1.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); + + // BW Graph + auto& gradient_start = graph.GetOrCreateNodeArg("gradient_start", &tensor_int32); + auto& loss_grad_output = graph.GetOrCreateNodeArg("loss_grad_output", &tensor_int32); + auto& node_3_grad_output = graph.GetOrCreateNodeArg("node_3_grad_output", &tensor_int32); + auto& node_2_grad_output = graph.GetOrCreateNodeArg("node_2_grad_output", &tensor_int32); + auto& node_1_grad_output = graph.GetOrCreateNodeArg("node_1_grad_output", &tensor_int32); + + graph.AddNode("loss_grad", "Merge_Fake", "loss gradient", {&gradient_start, &output_arg3}, {&loss_grad_output}); + graph.AddNode("node_3_grad", "Merge_Fake", "node 3 gradient", {&loss_grad_output, &recomputed_arg3}, {&node_3_grad_output}); + graph.AddNode("node_2_grad", "Merge_Fake", "node 2 gradient", {&node_3_grad_output, &recomputed_arg2}, {&node_2_grad_output}); + graph.AddNode("node_1_grad", "Merge_Fake", "node 1 gradient", {&node_2_grad_output, &recomputed_arg1}, {&node_1_grad_output}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // PRIORITY_BASED order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + const std::vector expected_priority_based_order = { + "node_0", + "node_1", + "node_2", + "node_3", + "loss", + "loss_grad", + "node_3_recompute", + "node_3_grad", + "node_2_recompute", + "node_2_grad", + "node_1_recompute", + "node_1_grad", + }; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_priority_based_order[i]) << "Priority based execution order is wrong."; + } + } +} + TEST_F(GraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) { Model model("graph_1", false, *logger_); auto& graph = model.MainGraph(); diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index 19dbcca950..a851353413 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -12,7 +12,6 @@ #include "core/optimizer/rule_based_graph_transformer.h" using namespace ONNX_NAMESPACE; -using namespace std; namespace onnxruntime { namespace training { @@ -20,8 +19,8 @@ namespace training { using namespace common; GradientGraphBuilder::GradientGraphBuilder(Graph* graph, - const unordered_set& y_node_arg_names, - const unordered_set& x_node_arg_names, + const std::unordered_set& y_node_arg_names, + const std::unordered_set& x_node_arg_names, const std::string& loss_node_arg_name, const GradientGraphConfiguration& gradient_graph_config, const logging::Logger& logger) @@ -61,6 +60,11 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph, y_nodes_.insert(node); } + reachable_nodes_ = ReverseBFS(y_nodes_); + + std::string unreachable_nodes; + + // building x_nodes_ for (const auto& name : x_node_arg_names) { const NodeArg* node_arg = graph->GetNodeArg(name); if (!node_arg) { @@ -68,21 +72,29 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph, } x_node_args_.insert(node_arg); - vector nodes = graph_->GetConsumerNodes(name); + std::vector nodes = graph_->GetConsumerNodes(name); if (nodes.empty()) { ORT_THROW(name, " couldn't find the consumer node."); } - string grad_arg_name = GradientBuilderBase::GradientName(name); - pending_[grad_arg_name] = static_cast(nodes.size()); + std::string grad_arg_name = GradientBuilderBase::GradientName(name); + pending_[grad_arg_name] = 0; - x_nodes_.insert(nodes.begin(), nodes.end()); + for (const Node* node : nodes) { + if (IsReachable(node)) { + pending_[grad_arg_name] += 1; + x_nodes_.insert(node); + } else { + unreachable_nodes.append(node->Name() + ", "); + } + } } + LOGS(logger_, WARNING) << "Following nodes are unreachable for gradient back propagation: " << unreachable_nodes; } -NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) { +NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) const { NodeSet visited(nodes); - deque queue(nodes.begin(), nodes.end()); + std::deque queue(nodes.begin(), nodes.end()); while (!queue.empty()) { const Node* n = queue.front(); @@ -106,13 +118,13 @@ NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) { return visited; } -Status GradientGraphBuilder::CheckNodeArgsReachable(const NodeSet& reachable_nodes) { +Status GradientGraphBuilder::CheckNodeArgsReachable() const { for (const NodeArg* node_arg : x_node_args_) { auto nodes = graph_->GetConsumerNodes(node_arg->Name()); bool reachable = false; for (const Node* node : nodes) { - if (reachable_nodes.find(node) != reachable_nodes.end()) { + if (IsReachable(node)) { reachable = true; break; } @@ -141,14 +153,13 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init gradient_graph_defs.AddInitializers({tensor_proto}); } - NodeSet reachable_nodes = ReverseBFS(y_nodes_); - - ORT_RETURN_IF_ERROR(CheckNodeArgsReachable(reachable_nodes)); + ORT_RETURN_IF_ERROR(CheckNodeArgsReachable()); // Going forward to figure out which node_args need backprop-ed. - deque queue(x_nodes_.begin(), x_nodes_.end()); + std::deque queue(x_nodes_.begin(), x_nodes_.end()); NodeSet visited(x_nodes_); - unordered_set visited_node_args = x_node_args_; + + std::unordered_set visited_node_args = x_node_args_; visited_node_args.insert(y_node_args_.begin(), y_node_args_.end()); while (!queue.empty()) { @@ -158,7 +169,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init for (auto edge_it = node->OutputEdgesBegin(); edge_it != node->OutputEdgesEnd(); ++edge_it) { const Node& next_node = edge_it->GetNode(); - if (reachable_nodes.find(&next_node) == reachable_nodes.end()) continue; + if (!IsReachable(&next_node)) continue; auto it = STOP_GRADIENT_EDGES.find(next_node.OpType()); if (it != STOP_GRADIENT_EDGES.end() && it->second.count(edge_it->GetDstArgIndex())) { @@ -168,7 +179,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init } const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()]; - string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name()); + std::string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name()); pending_[grad_node_arg_name] += 1; @@ -185,7 +196,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init // visited_node_args are the node_args involved for (auto node : visited) { //TODO: might not need two sets, the union of them might be enough - unordered_set input_args_need_grad, output_args_need_grad; + std::unordered_set input_args_need_grad, output_args_need_grad; for (auto arg : node->InputDefs()) { if (visited_node_args.find(arg) != visited_node_args.end()) { input_args_need_grad.insert(arg->Name()); @@ -205,7 +216,7 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init auto found = pending_.find(arg.name); if (found != pending_.end() && found->second > 1) { auto idx = gradients_to_accumulate_[arg].size(); - string indexed_arg_name = arg.name + "_" + to_string(idx); + std::string indexed_arg_name = arg.name + "_" + to_string(idx); gradients_to_accumulate_[arg].push_back(ArgDef(indexed_arg_name, arg.type_proto)); arg.name = indexed_arg_name; diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 0713dc2768..53cbd147f0 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -89,6 +89,7 @@ class GradientGraphBuilder { NodeSet y_nodes_; NodeSet x_nodes_; + NodeSet reachable_nodes_; Graph* graph_; @@ -117,14 +118,22 @@ class GradientGraphBuilder { @param nodes Starting nodes for ReverseBFS @returns All the nodes visited during ReverseBFS */ - NodeSet ReverseBFS(const NodeSet& nodes); + NodeSet ReverseBFS(const NodeSet& nodes) const; /** Check if 'x_node_args_' are reachable from 'y_node_args_' for computing the partial derivative @param reachable_nodes All the nodes reachable from the 'y_node_args_' @returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status */ - Status CheckNodeArgsReachable(const NodeSet& reachable_nodes); + + Status CheckNodeArgsReachable() const; + + /** + Check if node is reachable from the 'y_node_args_' + **/ + bool IsReachable(const Node* node) const { + return reachable_nodes_.find(node) != reachable_nodes_.end(); + } }; } // namespace training diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index f6e65927d0..0a87220c6b 100644 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -663,9 +663,8 @@ IMPLEMENT_GRADIENT_BUILDER(GetReshapeGradient) { } } return std::vector{ - NodeDef("ReshapeGrad", - {I(0), GO(0)}, - {GI(0)})}; + NodeDef("Shape", {I(0)}, {IA("x_shape")}), + NodeDef("Reshape", {GO(0), IA("x_shape")}, {GI(0)})}; } IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_builder_base.h b/orttraining/orttraining/core/graph/gradient_builder_base.h index da2b58eaa9..cf8a496f69 100644 --- a/orttraining/orttraining/core/graph/gradient_builder_base.h +++ b/orttraining/orttraining/core/graph/gradient_builder_base.h @@ -95,6 +95,15 @@ class GradientBuilderBase { // i-th output of forward op ArgDef O(const size_t i) const { ORT_ENFORCE(i < node_->OutputDefs().size()); + + const std::string& name = node_->OutputDefs()[i]->Name(); + const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name)); + if (recomputed_nodearg) { + const Node* producer_node = graph_->GetProducerNode(name); + LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name(); + return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto()); + } + return ArgDef(node_->OutputDefs()[i]->Name(), node_->OutputDefs()[i]->TypeAsProto()); } diff --git a/orttraining/orttraining/core/optimizer/dropout_recompute.cc b/orttraining/orttraining/core/optimizer/dropout_recompute.cc new file mode 100644 index 0000000000..1135fb4578 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/dropout_recompute.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/dropout_recompute.h" +#include "orttraining/core/graph/recompute_graph_utils.h" + +namespace onnxruntime { + +Node& InsertDropoutRecompute(Graph& graph, Node& node, bool use_original_input) { + NodeArg* input = node.MutableInputDefs()[0]; + if (!use_original_input) { + auto& recomputed_input = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(input->Name()), + input->TypeAsProto()); + input = &recomputed_input; + } + + const auto& output = node.OutputDefs()[0]; + auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), + output->TypeAsProto()); + + Node& recompute_node = graph.AddNode(node.Name() + "_recompute", + "DropoutGrad", + "Recompute of " + node.Name(), + { + input, // X + node.MutableOutputDefs()[1], // mask + node.MutableInputDefs()[1], // ratio + node.MutableInputDefs()[2] // training_mode + + }, + {&recomputed_output}, + {}, + kMSDomain); + + return recompute_node; +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/dropout_recompute.h b/orttraining/orttraining/core/optimizer/dropout_recompute.h new file mode 100644 index 0000000000..c624e5194d --- /dev/null +++ b/orttraining/orttraining/core/optimizer/dropout_recompute.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/graph.h" + +namespace onnxruntime { + +Node& InsertDropoutRecompute(Graph& graph, Node& node, bool use_original_input); + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 7e2575e4ad..4b449ba3ce 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -43,6 +43,7 @@ #include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" +#include "orttraining/core/optimizer/transformer_layer_recompute.h" namespace onnxruntime { namespace training { @@ -73,10 +74,10 @@ std::vector> GeneratePreTrainingTransformers( rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); - if (config.gelu_checkpoint) { + if (config.gelu_recompute) { rule_transformer->Register(make_unique()); } - if (config.attn_dropout_checkpoint) { + if (config.attn_dropout_recompute) { rule_transformer->Register(make_unique()); } @@ -104,6 +105,10 @@ std::vector> GeneratePreTrainingTransformers( horizontal_parallel_size, compatible_eps)); } transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); + + if (config.transformer_layer_recompute) { + transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); + } } break; case TransformerLevel::Level2: { diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.cc b/orttraining/orttraining/core/optimizer/localized_recompute.cc index 7a52f69805..02164fa0dc 100644 --- a/orttraining/orttraining/core/optimizer/localized_recompute.cc +++ b/orttraining/orttraining/core/optimizer/localized_recompute.cc @@ -4,6 +4,7 @@ #include "core/graph/graph_utils.h" #include "orttraining/core/graph/recompute_graph_utils.h" #include "orttraining/core/optimizer/localized_recompute.h" +#include "orttraining/core/optimizer/dropout_recompute.h" using namespace ONNX_NAMESPACE; @@ -23,13 +24,15 @@ Status GeluRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), output->TypeAsProto()); - graph.AddNode(node.Name() + "_recompute", - node.OpType(), - "Recompute of " + node.Name(), - {node.MutableInputDefs()[0]}, - {&recomputed_output}, - &node.GetAttributes(), - node.Domain()); + Node& recompute_node = graph.AddNode(node.Name() + "_recompute", + node.OpType(), + "Recompute of " + node.Name(), + {node.MutableInputDefs()[0]}, + {&recomputed_output}, + &node.GetAttributes(), + node.Domain()); + + recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; return Status::OK(); @@ -46,23 +49,8 @@ bool AttentionDropoutRecompute::SatisfyCondition(const Graph& /*graph*/, const N } Status AttentionDropoutRecompute::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& /*logger*/) const { - const auto& output = node.OutputDefs()[0]; - - auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), - output->TypeAsProto()); - - graph.AddNode(node.Name() + "_recompute", - "DropoutGrad", // Reusing DropoutGrad as the recompute op - "Recompute of " + node.Name(), - { - node.MutableInputDefs()[0], // X - node.MutableOutputDefs()[1], // mask - node.MutableInputDefs()[1], // ratio - node.MutableInputDefs()[2] // training_mode - }, - {&recomputed_output}, - {}, - kMSDomain); + Node& recompute_node = InsertDropoutRecompute(graph, node, /*use_original_input*/ true); + recompute_node.SetPriority(static_cast(ExecutionPriority::LOCAL_LOW)); rule_effect = RewriteRuleEffect::kModifiedRestOfGraph; return Status::OK(); diff --git a/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc new file mode 100644 index 0000000000..c43192a129 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/transformer_layer_recompute.h" +#include "orttraining/core/optimizer/dropout_recompute.h" +#include "orttraining/core/graph/recompute_graph_utils.h" +#include "core/common/common.h" + +#include + +namespace onnxruntime { +Status TransformerLayerRecompute::IdentifyTransformerLayerEdges( + const Graph& graph, + std::vector>& start_end_edges, + const logging::Logger& logger) const { + const std::unordered_set gelu_ops{"Gelu", "BiasGelu", "FastGelu"}; + const std::unordered_set dropout_ops{"Dropout", "BiasDropout"}; + const std::unordered_set layernorm_ops{"LayerNormalization", "SkipLayerNormalization"}; + + std::vector layer_start_edges, layer_end_edges; + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + for (auto node_index : node_topology_list) { + auto& node = *graph.GetNode(node_index); + + // Look for start of a transformer layer + if ((layernorm_ops.find(node.OpType()) != layernorm_ops.end() || + dropout_ops.find(node.OpType()) != dropout_ops.end()) && + node.GetOutputEdgesCount() == 4) { + layer_start_edges.push_back(node.OutputDefs()[0]); + } + + // Look for end of a transformer layer + if (gelu_ops.find(node.OpType()) != gelu_ops.end()) { + auto next_node = node.OutputNodesBegin(); + + while (next_node->OutputNodesBegin() != next_node->OutputNodesEnd() && + dropout_ops.find(next_node->OpType()) == dropout_ops.end()) { + next_node = next_node->OutputNodesBegin(); + } + + while (next_node->OutputNodesBegin() != next_node->OutputNodesEnd() && + layernorm_ops.find(next_node->OpType()) == layernorm_ops.end()) { + next_node = next_node->OutputNodesBegin(); + } + + if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { + layer_end_edges.push_back(next_node->OutputDefs()[0]); + } + } + } + + ORT_RETURN_IF_NOT(layer_start_edges.size() == layer_end_edges.size(), + "Number of start and end edges doesn't match!, #start=", layer_start_edges.size(), + ", #end=", layer_end_edges.size()); + + start_end_edges.clear(); + + LOGS(logger, INFO) << "Found " << layer_start_edges.size() << " transformer layers."; + for (size_t i = 0; i < layer_start_edges.size(); ++i) { + start_end_edges.push_back({layer_start_edges[i], layer_end_edges[i]}); + LOGS(logger, INFO) << "Start edge: " << layer_start_edges[i]->Name() << " End edge: " << layer_end_edges[i]->Name(); + } + + return Status::OK(); +} + +namespace { + +typedef std::set NodeSet; + +NodeSet BFSFrom(const std::vector& start_nodes, bool reverse) { + NodeSet visited(start_nodes.begin(), start_nodes.end()); + std::deque queue(start_nodes.begin(), start_nodes.end()); + while (!queue.empty()) { + const Node* n = queue.front(); + queue.pop_front(); + + auto begin = reverse ? n->InputNodesBegin() : n->OutputNodesBegin(); + auto end = reverse ? n->InputNodesEnd() : n->OutputNodesEnd(); + + for (auto node_it = begin; node_it != end; ++node_it) { + const Node& node = *node_it; + if (visited.find(&node) == visited.end()) { + queue.push_back(&node); + visited.insert(&node); + } + } + } + return visited; +} +} // namespace + +std::vector TransformerLayerRecompute::NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const { + // Forward BFS from the start node + std::vector start_nodes = graph.GetConsumerNodes(start->Name()); + NodeSet fw_visited = BFSFrom(start_nodes, /*reverse*/ false); + + // Reverse BFS from the end node + const Node* end_node = graph.GetProducerNode(end->Name()); + NodeSet bw_visited = BFSFrom({end_node}, /*reverse*/ true); + + // Join fw_visited and bw_visited + std::vector intersect_nodes; + std::set_intersection(fw_visited.begin(), fw_visited.end(), + bw_visited.begin(), bw_visited.end(), + std::back_inserter(intersect_nodes), NodeCompare()); + + return intersect_nodes; +} + +void TransformerLayerRecompute::InsertRecomputeNodes(Graph& graph, const std::vector& nodes, int priority) const { + auto initializers = graph.GetAllInitializedTensors(); + + for (const Node* n : nodes) { + Node* node = graph.GetNode(n->Index()); + + // recomputed Dropout need to produce the same output as original dropout + // currently reusing original dropout's mask to achieve this + if (node->OpType() == "Dropout") { + const NodeArg* input = node->InputDefs()[0]; + const Node* p_node = graph.GetProducerNode(input->Name()); + + bool use_original_input = + initializers.find(input->Name()) != initializers.end() || + std::find(nodes.begin(), nodes.end(), p_node) == nodes.end(); + + Node& recompute_node = InsertDropoutRecompute(graph, *node, use_original_input); + recompute_node.SetPriority(priority); + continue; + } + + // prepare inputs for recompute node + std::vector recomputed_inputs; + for (NodeArg* input : node->MutableInputDefs()) { + const Node* p_node = graph.GetProducerNode(input->Name()); + + // do not duplicate initializers in recompute subgraph + if (initializers.find(input->Name()) != initializers.end() || + std::find(nodes.begin(), nodes.end(), p_node) == nodes.end()) { + recomputed_inputs.push_back(input); + } else { + auto& recomputed_input = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(input->Name()), + input->TypeAsProto()); + recomputed_inputs.push_back(&recomputed_input); + } + } + + // prepare ouputs for recompute node + std::vector recomputed_outputs; + for (NodeArg* output : node->MutableOutputDefs()) { + auto& recomputed_output = graph.GetOrCreateNodeArg(graph_utils::RecomputeName(output->Name()), + output->TypeAsProto()); + recomputed_outputs.push_back(&recomputed_output); + } + + Node& recompute_node = graph.AddNode(node->Name() + "_recompute", + node->OpType(), + "Recompute of " + node->Name(), + recomputed_inputs, + recomputed_outputs, + &node->GetAttributes(), + node->Domain()); + recompute_node.SetPriority(priority); + } + return; +} + +Status TransformerLayerRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& logger) const { + std::vector> start_end_edges; + + Status s = IdentifyTransformerLayerEdges(graph, start_end_edges, logger); + if (!s.IsOK()) { + modified = false; + return Status::OK(); + } + + // insert recompute nodes expect for the last transformer layer + // latter recompute layers have higher execution priorty + for (size_t i = 0; i < start_end_edges.size() - 1; ++i) { + std::vector nodes = NodesBetweenEdges(graph, start_end_edges[i].first, start_end_edges[i].second); + InsertRecomputeNodes(graph, nodes, static_cast(start_end_edges.size() - i)); + } + + modified = true; + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h new file mode 100644 index 0000000000..f3bb00ff20 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/transformer_layer_recompute.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +class TransformerLayerRecompute : public GraphTransformer { + public: + TransformerLayerRecompute(const std::unordered_set& compatible_execution_providers = {}) noexcept + : GraphTransformer("TransformerLayerRecompute", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: + Status IdentifyTransformerLayerEdges(const Graph& graph, + std::vector>& start_end_edges, + const logging::Logger& logger) const; + + std::vector NodesBetweenEdges(const Graph& graph, const NodeArg* start, const NodeArg* end) const; + + void InsertRecomputeNodes(Graph& graph, const std::vector& nodes, int priority) const; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 1f5e04a4a4..b49db28db2 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -217,11 +217,6 @@ Status TrainingSession::ConfigureForTraining( config_result.mixed_precision_config_result = mp_result; } - if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) { - ORT_IGNORE_RETURN_VALUE(Save( - config.model_with_loss_function_path.value(), SaveOption::NO_RELOAD)); - } - // We need to get trainable weights to prevent constant folding from them. This works well if trainable weights are passed from config. // For case we use GetTrainableModelInitializers to get trainable weights such as C++ frontend, it may get more initializers // than trainable weights here as it's before transformers. So the constant folding may miss some nodes we actually can fold. @@ -239,6 +234,11 @@ Status TrainingSession::ConfigureForTraining( ORT_RETURN_IF_ERROR(ApplyTransformationsToMainGraph(trainable_initializers, config.graph_transformer_config)); + if (IsRootNode(config) && config.model_with_loss_function_path.has_value()) { + ORT_IGNORE_RETURN_VALUE(Save( + config.model_with_loss_function_path.value(), SaveOption::NO_RELOAD)); + } + // derive actual set of weights to train std::unordered_set weight_names_to_train = !filtered_config_weight_names_to_train.empty() diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 38e184e83e..3fc97a306a 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -194,10 +194,12 @@ class TrainingSession : public InferenceSession { struct GraphTransformerConfiguration { // Whether to enable GELU approximation which is faster but produces different results. bool enable_gelu_approximation{false}; - // Enable checkpointing of attention dropout to save memory - bool attn_dropout_checkpoint{false}; - // Enable checkpointing of Gelu activation output to save memory - bool gelu_checkpoint{false}; + // Enable recompute of attention dropout to save memory + bool attn_dropout_recompute{false}; + // Enable recompute of Gelu activation output to save memory + bool gelu_recompute{false}; + // Enable recompute of transformer layer ouput to save memory + bool transformer_layer_recompute{false}; }; GraphTransformerConfiguration graph_transformer_config{}; diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index 98db443c64..91a1aa1d88 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -167,9 +167,11 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet cxxopts::value()->default_value("true")) ("enable_gelu_approximation", "Specify whether to enable GELU approximation.", cxxopts::value()->default_value("true")) - ("attn_dropout_checkpoint", "Enable checkpointing of attention dropout to save memory.", + ("attn_dropout_recompute", "Enable checkpointing of attention dropout to save memory.", cxxopts::value()->default_value("false")) - ("gelu_checkpoint", "Enable checkpointing of Gelu activation output to save memory.", + ("gelu_recompute", "Enable checkpointing of Gelu activation output to save memory.", + cxxopts::value()->default_value("false")) + ("transformer_layer_recompute", "Enable checkpointing of transformer layer output to save memory.", cxxopts::value()->default_value("false")) ("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)", cxxopts::value()->default_value("false")); @@ -458,8 +460,9 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet } params.enable_gelu_approximation = flags["enable_gelu_approximation"].as(); - params.attn_dropout_checkpoint = flags["attn_dropout_checkpoint"].as(); - params.gelu_checkpoint = flags["gelu_checkpoint"].as(); + params.attn_dropout_recompute = flags["attn_dropout_recompute"].as(); + params.gelu_recompute = flags["gelu_recompute"].as(); + params.transformer_layer_recompute = flags["transformer_layer_recompute"].as(); ort_params.log_severity = static_cast(flags["ort_log_severity"].as()); ORT_RETURN_IF_NOT( diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 5c36156a11..85fab1c5e1 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -34,6 +34,7 @@ namespace training { static std::vector overrides = {}; static SessionOptions SESSION_OPTION = { ExecutionMode::ORT_SEQUENTIAL, //execution_mode + ExecutionOrder::PRIORITY_BASED, //execution_order false, //enable_profiling ORT_TSTR(""), //optimized_model_filepath true, //enable_mem_pattern @@ -183,8 +184,9 @@ Status TrainingRunner::Initialize() { { TrainingSession::TrainingConfiguration::GraphTransformerConfiguration gt_config{}; gt_config.enable_gelu_approximation = params_.enable_gelu_approximation; - gt_config.attn_dropout_checkpoint = params_.attn_dropout_checkpoint; - gt_config.gelu_checkpoint = params_.gelu_checkpoint; + gt_config.attn_dropout_recompute = params_.attn_dropout_recompute; + gt_config.gelu_recompute = params_.gelu_recompute; + gt_config.transformer_layer_recompute = params_.transformer_layer_recompute; config.graph_transformer_config = gt_config; } diff --git a/orttraining/orttraining/models/runner/training_runner.h b/orttraining/orttraining/models/runner/training_runner.h index b69b30fe7c..28024f1f6c 100644 --- a/orttraining/orttraining/models/runner/training_runner.h +++ b/orttraining/orttraining/models/runner/training_runner.h @@ -175,10 +175,11 @@ class TrainingRunner { // Enable GELU approximation bool enable_gelu_approximation = false; // Enable checkpointing of attention dropout to save memory - bool attn_dropout_checkpoint = false; + bool attn_dropout_recompute = false; // Enable checkpointing of Gelu activation output to save memory - bool gelu_checkpoint = false; - + bool gelu_recompute = false; + // Enable checkpointing of transformer layer output to save memory + bool transformer_layer_recompute = false; // Use invertible layernorm grad bool use_invertible_layernorm_grad = false; };