diff --git a/cmake/horovod/CMakeLists.txt b/cmake/horovod/CMakeLists.txt index 6c9c428252..528ca8c55e 100644 --- a/cmake/horovod/CMakeLists.txt +++ b/cmake/horovod/CMakeLists.txt @@ -6,12 +6,6 @@ set(CMAKE_CXX_STANDARD 11) set(THIRD_PARTY_ROOT ${HOROVOD_ROOT}/third_party) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -B /data/anaconda/envs/ort/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -fPIC -DHAVE_MPI") - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=strict-aliasing -Wno-error=unused-parameter -Wno-error=ignored-qualifiers -Wno-error=maybe-uninitialized") - -add_definitions(-DEIGEN_MPL2_ONLY=1 -DHAVE_CUDA=1 -DHAVE_NCCL=1 -DHOROVOD_GPU_ALLREDUCE='N' -DHOROVOD_GPU_ALLGATHER='M' -DHOROVOD_GPU_BROADCAST='M') - file(GLOB_RECURSE horovod_common_src "${HOROVOD_ROOT}/horovod/common/mpi/mpi_context.cc" "${HOROVOD_ROOT}/horovod/common/mpi/mpi_controller.cc" @@ -46,9 +40,45 @@ file(GLOB_RECURSE horovod_common_src add_library(horovod ${horovod_common_src}) -target_compile_options(horovod PRIVATE ${MPI_CXX_COMPILE_FLAGS}) -target_compile_options(horovod PRIVATE -march=native) -target_compile_options(horovod PRIVATE -Wno-error=sign-compare) -target_compile_options(horovod PRIVATE -Wno-reorder) -target_compile_options(horovod PRIVATE -Wno-unused-variable) -target_include_directories(horovod PRIVATE ${THIRD_PARTY_ROOT}/eigen ${THIRD_PARTY_ROOT}/lbfgs/include ${THIRD_PARTY_ROOT}/boost/assert/include ${THIRD_PARTY_ROOT}/boost/config/include ${THIRD_PARTY_ROOT}/boost/core/include ${THIRD_PARTY_ROOT}/boost/detail/include ${THIRD_PARTY_ROOT}/boost/iterator/include ${THIRD_PARTY_ROOT}/boost/lockfree/include ${THIRD_PARTY_ROOT}/boost/mpl/include ${THIRD_PARTY_ROOT}/boost/parameter/include ${THIRD_PARTY_ROOT}/boost/predef/include ${THIRD_PARTY_ROOT}/boost/preprocessor/include ${THIRD_PARTY_ROOT}/boost/static_assert/include ${THIRD_PARTY_ROOT}/boost/type_traits/include ${THIRD_PARTY_ROOT}/boost/utility/include ${THIRD_PARTY_ROOT}/flatbuffers/include /usr/local/cuda/include ${MPI_CXX_INCLUDE_PATH}) +target_compile_definitions(horovod PRIVATE + EIGEN_MPL2_ONLY=1 + HAVE_CUDA=1 + HAVE_NCCL=1 + HAVE_MPI + HOROVOD_GPU_ALLREDUCE='N' + HOROVOD_GPU_ALLGATHER='M' + HOROVOD_GPU_BROADCAST='M') + +target_compile_options(horovod PRIVATE + -Wl,--sysroot=/ + -fwrapv + -fPIC + ${MPI_CXX_COMPILE_FLAGS} + -march=native + -Wall + -Wno-error=ignored-qualifiers + -Wno-error=sign-compare + -Wno-error=strict-aliasing + -Wno-error=unused-parameter + -Wno-reorder + -Wno-unused-variable) + +target_include_directories(horovod PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${MPI_CXX_INCLUDE_PATH} + ${REPO_ROOT}/cmake/external/eigen # use ORT's eigen instead of ${THIRD_PARTY_ROOT}/eigen + ${THIRD_PARTY_ROOT}/lbfgs/include + ${THIRD_PARTY_ROOT}/boost/assert/include + ${THIRD_PARTY_ROOT}/boost/config/include + ${THIRD_PARTY_ROOT}/boost/core/include + ${THIRD_PARTY_ROOT}/boost/detail/include + ${THIRD_PARTY_ROOT}/boost/iterator/include + ${THIRD_PARTY_ROOT}/boost/lockfree/include + ${THIRD_PARTY_ROOT}/boost/mpl/include + ${THIRD_PARTY_ROOT}/boost/parameter/include + ${THIRD_PARTY_ROOT}/boost/predef/include + ${THIRD_PARTY_ROOT}/boost/preprocessor/include + ${THIRD_PARTY_ROOT}/boost/static_assert/include + ${THIRD_PARTY_ROOT}/boost/type_traits/include + ${THIRD_PARTY_ROOT}/boost/utility/include + ${THIRD_PARTY_ROOT}/flatbuffers/include) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 7ab1d6cdb8..12087aa7e0 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -786,17 +786,11 @@ class Graph { void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto); const Node* GetProducerNode(const std::string& node_arg_name) const { - auto iter = node_arg_to_producer_node_.find(node_arg_name); - - if (iter != node_arg_to_producer_node_.end()) { - auto node_index = iter->second; - return GetNode(node_index); - } - return nullptr; + return GetProducerNodeImpl(*this, node_arg_name); } Node* GetMutableProducerNode(const std::string& node_arg_name) { - return const_cast(GetProducerNode(node_arg_name)); + return GetProducerNodeImpl(*this, node_arg_name); } void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) { @@ -810,25 +804,11 @@ class Graph { } std::vector GetConsumerNodes(const std::string& node_arg_name) const { - std::vector results; - auto iter = node_arg_to_consumer_nodes_.find(node_arg_name); - if (iter != node_arg_to_consumer_nodes_.end()) { - for (auto node_index : iter->second) { - results.push_back(GetNode(node_index)); - } - } - return results; + return GetConsumerNodesImpl(*this, node_arg_name); } std::vector GetMutableConsumerNodes(const std::string& node_arg_name) { - std::vector results; - auto iter = node_arg_to_consumer_nodes_.find(node_arg_name); - if (iter != node_arg_to_consumer_nodes_.end()) { - for (auto node_index : iter->second) { - results.push_back(GetNode(node_index)); - } - } - return results; + return GetConsumerNodesImpl(*this, node_arg_name); } void UpdateConsumerNodes(const std::string& node_arg_name, const std::vector& nodes) { @@ -1044,6 +1024,33 @@ class Graph { void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const; + template + static auto GetProducerNodeImpl( + TInstance& instance, const std::string& node_arg_name) + -> decltype(instance.GetNode(0)) { + auto iter = instance.node_arg_to_producer_node_.find(node_arg_name); + if (iter != instance.node_arg_to_producer_node_.end()) { + auto node_index = iter->second; + return instance.GetNode(node_index); + } + return nullptr; + } + + template + static auto GetConsumerNodesImpl( + TInstance& instance, const std::string& node_arg_name) + -> std::vector { + std::vector results; + auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name); + if (iter != instance.node_arg_to_consumer_nodes_.end()) { + results.reserve(iter->second.size()); + for (auto node_index : iter->second) { + results.push_back(instance.GetNode(node_index)); + } + } + return results; + } + const Model& owning_model_; // GraphProto to store name, version, initializer. diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index f3fb44620f..a3aea2c62c 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -109,7 +109,7 @@ class GraphViewer { bool IsSubgraph() const; /** Get the internal graph*/ - const Graph* GetGraph() const { return graph_; } + const Graph& GetGraph() const { return *graph_; } /** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 2c7560df19..6ec3d6a265 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -462,7 +462,7 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue if (to_be_executed_nodes_.find(fetch_mlvalue_idxs) != to_be_executed_nodes_.end()) return; - const Graph* graph = GetGraphViewer()->GetGraph(); + const Graph& graph = GetGraphViewer()->GetGraph(); // Get the nodes generating the fetches. std::vector nodes; @@ -475,12 +475,12 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector& fetch_mlvalue to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes)); return; } - auto ending_node = graph->GetProducerNode(node_arg_name); + auto ending_node = graph.GetProducerNode(node_arg_name); nodes.push_back(ending_node); } // Reversely traverse to get reachable nodes. - graph->ReverseDFSFrom( + graph.ReverseDFSFrom( nodes, {}, [&reachable_nodes](const Node* n) { reachable_nodes.insert(n->Index()); }); to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes)); } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 657e87f1b6..a82535c0ba 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2895,7 +2895,7 @@ Graph::~Graph() { std::ostream& operator<<(std::ostream& out, const Graph& graph) { out << "Inputs:\n"; for (auto* x : graph.GetInputs()) { - out << " " << x->Name() << " : " << *x->Type() << std::endl; + out << " " << x->Name() << " : " << *x->Type() << "\n"; } out << "Nodes:\n"; for (auto& node : graph.Nodes()) { @@ -2913,11 +2913,11 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) { } out << ", "; } - out << std::endl; + out << "\n"; } out << "Outputs:\n"; for (auto* x : graph.GetOutputs()) { - out << " " << x->Name() << " : " << *x->Type() << std::endl; + out << " " << x->Name() << " : " << *x->Type() << "\n"; } return out; }