From 2ce90cff4c4ea04fb50d28aba3c2f5dbd4beaca5 Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Wed, 1 Apr 2020 10:36:16 -0700 Subject: [PATCH] PR comments (#3374) * PR comments * PR comments * PR comments * PR comments * PR comments * PR comments * PR comments Co-authored-by: Ethan Tao --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_framework.cmake | 3 -- .../core/framework/sequential_executor.cc | 30 ++++++------- onnxruntime/core/framework/session_state.cc | 20 ++++++--- onnxruntime/core/framework/session_state.h | 4 ++ .../core/optimizer/matmul_transpose_fusion.cc | 45 ++++++++++--------- 6 files changed, 57 insertions(+), 46 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 8b883c5210..40cb1d361a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -796,6 +796,7 @@ if (onnxruntime_USE_DML) endif() if (onnxruntime_ENABLE_TRAINING) + add_compile_definitions(ENABLE_TRAINING) if (onnxruntime_USE_HOROVOD) if (WIN32) message( FATAL_ERROR "Horovod is not supported on Windows." ) diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 75df20d4db..43e08cbe1e 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -30,6 +30,3 @@ if (WIN32) set_target_properties(onnxruntime_framework PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/ConfigureVisualStudioCodeAnalysis.props) endif() -if(onnxruntime_ENABLE_TRAINING) - target_compile_definitions(onnxruntime_framework PUBLIC ENABLE_TRAINING) -endif() diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 732d986f4b..aaf0333f30 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -187,12 +187,6 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ", node.Name()); - std::string node_name = node.Name(); - if (node_name.empty()) { - // Node name field is often blank in execution graph, so derive something meaningful for profile traces and logs. - node_name = node.OpType() + "_" + std::to_string(node_index); - } - #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT LARGE_INTEGER kernel_start; QueryPerformanceCounter(&kernel_start); @@ -241,22 +235,26 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: utils::DumpNodeInputs(op_kernel_context, p_op_kernel->Node()); #endif + const std::string node_name_for_profiling = [&]() -> std::string { + if (!is_profiler_enabled) return {}; + // Derive something meaningful for profile traces and logs if node name field is blank in execution graph + return node.Name().empty() ? MakeString(node.OpType(), "_", node_index) : node.Name(); + }(); + if (is_profiler_enabled) { session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, - node_name + "_fence_before", + node_name_for_profiling + "_fence_before", sync_time_begin, {{"op_name", p_op_kernel->KernelDef().OpName()}}); // call compute on the kernel - VLOGS(logger, 1) << "Computing kernel: " << node_name; + VLOGS(logger, 1) << "Computing kernel: " << node_name_for_profiling; kernel_begin_time = session_state.Profiler().StartTime(); - } - if (is_profiler_enabled) { // Calculate total input sizes for this operation. CalculateTotalInputSizes(&op_kernel_context, p_op_kernel, - input_activation_sizes,input_parameter_sizes, node_name); + input_activation_sizes,input_parameter_sizes, node_name_for_profiling); } #ifdef CONCURRENCY_VISUALIZER @@ -286,12 +284,12 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: if (is_profiler_enabled) { // Calculate total output sizes for this operation. - CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name); + CalculateTotalOutputSizes(&op_kernel_context, total_output_sizes, node_name_for_profiling); #if defined(TRACE_EXECUTION) // Trace execution step. const Node& node = p_op_kernel->Node(); - std::cout << "Executed op kernel node " << node_name + std::cout << "Executed op kernel node " << node_name_for_profiling << " Index=" << node.Index() << " OpType=" << node.OpType() << " Name=" << node.Name() @@ -302,7 +300,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: #endif session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, - node_name + "_kernel_time", + node_name_for_profiling + "_kernel_time", kernel_begin_time, // Log additional operation args / info. { @@ -356,7 +354,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: #endif if (is_profiler_enabled) { session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, - node_name + "_fence_after", + node_name_for_profiling + "_fence_after", sync_time_begin, {{"op_name", p_op_kernel->KernelDef().OpName()}}); } @@ -366,7 +364,7 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: #endif // free ml-values corresponding to this node - VLOGS(logger, 1) << "Releasing node ML values after computing kernel: " << node_name; + VLOGS(logger, 1) << "Releasing node ML values."; ORT_RETURN_IF_ERROR(ReleaseNodeMLValues(frame, seq_exec_plan, node_exec_plan, logger)); } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6ec3d6a265..8ef57b2023 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -136,6 +136,7 @@ const std::unordered_map& SessionState::GetConstantInitializedTen return constant_initialized_tensors_; } +#ifdef ENABLE_TRAINING Status SessionState::GetInitializedTensors( const std::unordered_set& interested_weights, bool allow_missing_weights, NameMLValMap& retrieved_weights) const { @@ -162,6 +163,7 @@ NameMLValMap SessionState::GetInitializedTensors(const std::unordered_set& feeds, std::unordered_map& out) { for (const auto* input : graph.GetInputs()) { @@ -280,6 +283,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector>& input_shapes, const std::vector& feed_mlvalue_idxs) const { @@ -459,7 +463,9 @@ const NodeIndexInfo& SessionState::GetNodeIndexInfo() const { } void SessionState::UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue_idxs) { - if (to_be_executed_nodes_.find(fetch_mlvalue_idxs) != to_be_executed_nodes_.end()) + std::vector sorted_idxs = fetch_mlvalue_idxs; + std::sort(sorted_idxs.begin(), sorted_idxs.end()); + if (to_be_executed_nodes_.find(sorted_idxs) != to_be_executed_nodes_.end()) return; const Graph& graph = GetGraphViewer()->GetGraph(); @@ -471,10 +477,8 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue for (auto idx : fetch_mlvalue_idxs) { std::string node_arg_name; - if (!this->GetOrtValueNameIdxMap().GetName(idx, node_arg_name).IsOK()) { - to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes)); - return; - } + const auto status = this->GetOrtValueNameIdxMap().GetName(idx, node_arg_name); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); auto ending_node = graph.GetProducerNode(node_arg_name); nodes.push_back(ending_node); } @@ -482,12 +486,14 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue // Reversely traverse to get reachable nodes. graph.ReverseDFSFrom( nodes, {}, [&reachable_nodes](const Node* n) { reachable_nodes.insert(n->Index()); }); - to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes)); + to_be_executed_nodes_.insert(std::make_pair(sorted_idxs, reachable_nodes)); } const std::unordered_set* SessionState::GetToBeExecutedNodes( const std::vector& fetch_mlvalue_idxs) const { - auto it = to_be_executed_nodes_.find(fetch_mlvalue_idxs); + std::vector sorted_idxs = fetch_mlvalue_idxs; + std::sort(sorted_idxs.begin(), sorted_idxs.end()); + auto it = to_be_executed_nodes_.find(sorted_idxs); return (it != to_be_executed_nodes_.end()) ? &it->second : nullptr; } diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 1db7e71144..473eb9acc6 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -116,6 +116,7 @@ class SessionState { */ const std::unordered_map& GetConstantInitializedTensors() const; +#ifdef ENABLE_TRAINING /** Get some initialized tensors (weights). @param interested_weights The names of the weights to retrieve. @@ -133,6 +134,7 @@ class SessionState { Any names in interested_weights with no corresponding weight are ignored. */ NameMLValMap GetInitializedTensors(const std::unordered_set& interested_weights) const; +#endif // execution plan void SetExecutionPlan(std::unique_ptr p_seq_exec_plan); @@ -249,10 +251,12 @@ class SessionState { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SessionState); +#ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( const std::vector>& input_shape, const std::vector& feed_mlvalue_idxs, MemoryPatternGroup* output) const; +#endif // cache of the constructed kernels to avoid spending construction // time per executor diff --git a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc index fe4e5e48f1..b78138ea6b 100644 --- a/onnxruntime/core/optimizer/matmul_transpose_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_transpose_fusion.cc @@ -10,16 +10,21 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { -static std::pair IsInputTranspose(const Graph& graph, NodeArg& node_arg) { - const auto& trans_node = graph.GetProducerNode(node_arg.Name()); +static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) { + Node* trans_node = graph.GetMutableProducerNode(node_arg.Name()); if (trans_node == nullptr || trans_node->OpType() != "Transpose") { - return std::make_pair(false, nullptr); + return nullptr; + } + + // if the node has Graph output, skip it too + if (!graph.GetNodeOutputsInGraphOutputs(*trans_node).empty()) { + return nullptr; } auto perms = RetrieveValues(trans_node->GetAttributes().at("perm")); int64_t rank = perms.size(); if (rank < 2) { - return std::make_pair(false, nullptr); + return nullptr; } bool is_trans_on_last_two_dims = true; @@ -31,14 +36,14 @@ static std::pair IsInputTranspose(const Graph& graph, NodeArg& node } if (is_trans_on_last_two_dims) { - is_trans_on_last_two_dims = (int64_t)perms[rank - 2] == rank - 1 && (int64_t)perms[rank - 1] == rank - 2; + is_trans_on_last_two_dims = perms[rank - 2] == rank - 1 && perms[rank - 1] == rank - 2; } if (!is_trans_on_last_two_dims) { - return std::make_pair(false, nullptr); + return nullptr; } - return std::make_pair(true, const_cast(trans_node)); + return trans_node; } static size_t UpdateConsumerCount(Graph& graph, NodeArg* target, std::unordered_map& count_map) { @@ -72,27 +77,27 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ } NodeArg* left_input = node.MutableInputDefs()[0]; - auto left = IsInputTranspose(graph, *left_input); + auto left = GetTransposeNodeFromOutput(graph, *left_input); NodeArg* right_input = node.MutableInputDefs()[1]; - auto right = IsInputTranspose(graph, *right_input); + auto right = GetTransposeNodeFromOutput(graph, *right_input); - if (!left.first && !right.first) { + if (!left && !right) { continue; } - if (left.first) { + if (left) { size_t left_consumers = UpdateConsumerCount(graph, left_input, consumer_count); if (left_consumers == 0) - removed_nodes.push_front(left.second->Index()); - left_input = left.second->MutableInputDefs()[0]; + removed_nodes.push_front(left->Index()); + left_input = left->MutableInputDefs()[0]; } - if (right.first) { + if (right) { size_t right_consumers = UpdateConsumerCount(graph, right_input, consumer_count); if (right_consumers == 0) - removed_nodes.push_front(right.second->Index()); - right_input = right.second->MutableInputDefs()[0]; + removed_nodes.push_front(right->Index()); + right_input = right->MutableInputDefs()[0]; } const std::vector input_defs{left_input, right_input}; @@ -103,16 +108,16 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ "fused MatMul and Transpose ", input_defs, output_defs, {}, kMSDomain); - bool transpose_left = left.first; + bool transpose_left = (left != nullptr); if (node.OpType() == "TransposeMatMul") { transpose_left ^= static_cast(node.GetAttributes().at("transA").i()); } - bool transpose_right = right.first; + bool transpose_right = (right != nullptr); if (node.OpType() == "TransposeMatMul") { transpose_right ^= static_cast(node.GetAttributes().at("transB").i()); } - matmul_node.AddAttribute("transA", (int64_t)transpose_left); - matmul_node.AddAttribute("transB", (int64_t)transpose_right); + matmul_node.AddAttribute("transA", static_cast(transpose_left)); + matmul_node.AddAttribute("transB", static_cast(transpose_right)); // Assign provider to this new node. Provider should be same as the provider for old node. matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());