mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Address master merge PR comments (#3348)
Address some comments from https://github.com/microsoft/onnxruntime/pull/3174:
- https://github.com/microsoft/onnxruntime/pull/3174#discussion_r396855459
- https://github.com/microsoft/onnxruntime/pull/3174#discussion_r396855630
- https://github.com/microsoft/onnxruntime/pull/3174#discussion_r396857140
- https://github.com/microsoft/onnxruntime/pull/3174#discussion_r398094858
- https://github.com/microsoft/onnxruntime/pull/3174#issuecomment-599024924
This commit is contained in:
parent
d8f0a0f223
commit
fb2f97a002
5 changed files with 80 additions and 43 deletions
|
|
@ -6,12 +6,6 @@ set(CMAKE_CXX_STANDARD 11)
|
|||
|
||||
set(THIRD_PARTY_ROOT ${HOROVOD_ROOT}/third_party)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -B /data/anaconda/envs/ort/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -fPIC -DHAVE_MPI")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=strict-aliasing -Wno-error=unused-parameter -Wno-error=ignored-qualifiers -Wno-error=maybe-uninitialized")
|
||||
|
||||
add_definitions(-DEIGEN_MPL2_ONLY=1 -DHAVE_CUDA=1 -DHAVE_NCCL=1 -DHOROVOD_GPU_ALLREDUCE='N' -DHOROVOD_GPU_ALLGATHER='M' -DHOROVOD_GPU_BROADCAST='M')
|
||||
|
||||
file(GLOB_RECURSE horovod_common_src
|
||||
"${HOROVOD_ROOT}/horovod/common/mpi/mpi_context.cc"
|
||||
"${HOROVOD_ROOT}/horovod/common/mpi/mpi_controller.cc"
|
||||
|
|
@ -46,9 +40,45 @@ file(GLOB_RECURSE horovod_common_src
|
|||
|
||||
add_library(horovod ${horovod_common_src})
|
||||
|
||||
target_compile_options(horovod PRIVATE ${MPI_CXX_COMPILE_FLAGS})
|
||||
target_compile_options(horovod PRIVATE -march=native)
|
||||
target_compile_options(horovod PRIVATE -Wno-error=sign-compare)
|
||||
target_compile_options(horovod PRIVATE -Wno-reorder)
|
||||
target_compile_options(horovod PRIVATE -Wno-unused-variable)
|
||||
target_include_directories(horovod PRIVATE ${THIRD_PARTY_ROOT}/eigen ${THIRD_PARTY_ROOT}/lbfgs/include ${THIRD_PARTY_ROOT}/boost/assert/include ${THIRD_PARTY_ROOT}/boost/config/include ${THIRD_PARTY_ROOT}/boost/core/include ${THIRD_PARTY_ROOT}/boost/detail/include ${THIRD_PARTY_ROOT}/boost/iterator/include ${THIRD_PARTY_ROOT}/boost/lockfree/include ${THIRD_PARTY_ROOT}/boost/mpl/include ${THIRD_PARTY_ROOT}/boost/parameter/include ${THIRD_PARTY_ROOT}/boost/predef/include ${THIRD_PARTY_ROOT}/boost/preprocessor/include ${THIRD_PARTY_ROOT}/boost/static_assert/include ${THIRD_PARTY_ROOT}/boost/type_traits/include ${THIRD_PARTY_ROOT}/boost/utility/include ${THIRD_PARTY_ROOT}/flatbuffers/include /usr/local/cuda/include ${MPI_CXX_INCLUDE_PATH})
|
||||
# Preprocessor definitions for the horovod target. PRIVATE keeps them scoped
# to horovod's own sources instead of defining them for every target.
# NOTE(review): the 'N' / 'M' values presumably select the NCCL and MPI
# implementations of each collective -- confirm against Horovod's build docs.
target_compile_definitions(horovod PRIVATE
|
||||
EIGEN_MPL2_ONLY=1
|
||||
HAVE_CUDA=1
|
||||
HAVE_NCCL=1
|
||||
HAVE_MPI
|
||||
HOROVOD_GPU_ALLREDUCE='N'
|
||||
HOROVOD_GPU_ALLGATHER='M'
|
||||
HOROVOD_GPU_BROADCAST='M')
|
||||
|
||||
# Compiler flags for the horovod target only (PRIVATE), listed in one call.
# -Wall enables the warnings; the -Wno-error=... entries keep the named
# warnings from being promoted to errors, and -Wno-reorder /
# -Wno-unused-variable silence those two warnings entirely.
# NOTE(review): -Wl,--sysroot=/ is a linker pass-through flag; having it in
# the compile options looks unintentional -- confirm whether it is needed
# at all or belongs in the link options instead.
# NOTE(review): -march=native ties the produced objects to the build
# machine's CPU -- confirm that is acceptable for distributed binaries.
target_compile_options(horovod PRIVATE
|
||||
-Wl,--sysroot=/
|
||||
-fwrapv
|
||||
-fPIC
|
||||
${MPI_CXX_COMPILE_FLAGS}
|
||||
-march=native
|
||||
-Wall
|
||||
-Wno-error=ignored-qualifiers
|
||||
-Wno-error=sign-compare
|
||||
-Wno-error=strict-aliasing
|
||||
-Wno-error=unused-parameter
|
||||
-Wno-reorder
|
||||
-Wno-unused-variable)
|
||||
|
||||
# Header search paths used when compiling horovod. PRIVATE means they are
# not propagated to targets that link against horovod.
# The CUDA and MPI paths come from CMake variables; Eigen is taken from
# ${REPO_ROOT}/cmake/external, and the remaining entries come from
# ${THIRD_PARTY_ROOT}.
target_include_directories(horovod PRIVATE
|
||||
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
|
||||
${MPI_CXX_INCLUDE_PATH}
|
||||
${REPO_ROOT}/cmake/external/eigen # use ORT's eigen instead of ${THIRD_PARTY_ROOT}/eigen
|
||||
${THIRD_PARTY_ROOT}/lbfgs/include
|
||||
${THIRD_PARTY_ROOT}/boost/assert/include
|
||||
${THIRD_PARTY_ROOT}/boost/config/include
|
||||
${THIRD_PARTY_ROOT}/boost/core/include
|
||||
${THIRD_PARTY_ROOT}/boost/detail/include
|
||||
${THIRD_PARTY_ROOT}/boost/iterator/include
|
||||
${THIRD_PARTY_ROOT}/boost/lockfree/include
|
||||
${THIRD_PARTY_ROOT}/boost/mpl/include
|
||||
${THIRD_PARTY_ROOT}/boost/parameter/include
|
||||
${THIRD_PARTY_ROOT}/boost/predef/include
|
||||
${THIRD_PARTY_ROOT}/boost/preprocessor/include
|
||||
${THIRD_PARTY_ROOT}/boost/static_assert/include
|
||||
${THIRD_PARTY_ROOT}/boost/type_traits/include
|
||||
${THIRD_PARTY_ROOT}/boost/utility/include
|
||||
${THIRD_PARTY_ROOT}/flatbuffers/include)
|
||||
|
|
|
|||
|
|
@ -786,17 +786,11 @@ class Graph {
|
|||
void SetNodeArgType(NodeArg& arg, const onnx::TypeProto& type_proto);
|
||||
|
||||
const Node* GetProducerNode(const std::string& node_arg_name) const {
|
||||
auto iter = node_arg_to_producer_node_.find(node_arg_name);
|
||||
|
||||
if (iter != node_arg_to_producer_node_.end()) {
|
||||
auto node_index = iter->second;
|
||||
return GetNode(node_index);
|
||||
}
|
||||
return nullptr;
|
||||
return GetProducerNodeImpl(*this, node_arg_name);
|
||||
}
|
||||
|
||||
Node* GetMutableProducerNode(const std::string& node_arg_name) {
|
||||
return const_cast<Node*>(GetProducerNode(node_arg_name));
|
||||
return GetProducerNodeImpl(*this, node_arg_name);
|
||||
}
|
||||
|
||||
void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) {
|
||||
|
|
@ -810,25 +804,11 @@ class Graph {
|
|||
}
|
||||
|
||||
std::vector<const Node*> GetConsumerNodes(const std::string& node_arg_name) const {
|
||||
std::vector<const Node*> results;
|
||||
auto iter = node_arg_to_consumer_nodes_.find(node_arg_name);
|
||||
if (iter != node_arg_to_consumer_nodes_.end()) {
|
||||
for (auto node_index : iter->second) {
|
||||
results.push_back(GetNode(node_index));
|
||||
}
|
||||
}
|
||||
return results;
|
||||
return GetConsumerNodesImpl(*this, node_arg_name);
|
||||
}
|
||||
|
||||
std::vector<Node*> GetMutableConsumerNodes(const std::string& node_arg_name) {
|
||||
std::vector<Node*> results;
|
||||
auto iter = node_arg_to_consumer_nodes_.find(node_arg_name);
|
||||
if (iter != node_arg_to_consumer_nodes_.end()) {
|
||||
for (auto node_index : iter->second) {
|
||||
results.push_back(GetNode(node_index));
|
||||
}
|
||||
}
|
||||
return results;
|
||||
return GetConsumerNodesImpl(*this, node_arg_name);
|
||||
}
|
||||
|
||||
void UpdateConsumerNodes(const std::string& node_arg_name, const std::vector<Node*>& nodes) {
|
||||
|
|
@ -1044,6 +1024,33 @@ class Graph {
|
|||
|
||||
void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const;
|
||||
|
||||
// Shared lookup behind GetProducerNode() and GetMutableProducerNode(), both of
// which forward *this as `instance`.
// Searches instance.node_arg_to_producer_node_ for `node_arg_name` and returns
// instance.GetNode(index) for the stored index, or nullptr when the name has
// no entry.
// The return type is taken from instance.GetNode(), so a single body serves
// the const wrapper and the non-const wrapper without a const_cast.
template <typename TInstance>
|
||||
static auto GetProducerNodeImpl(
|
||||
TInstance& instance, const std::string& node_arg_name)
|
||||
-> decltype(instance.GetNode(0)) {
|
||||
auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
|
||||
if (iter != instance.node_arg_to_producer_node_.end()) {
|
||||
auto node_index = iter->second;
|
||||
return instance.GetNode(node_index);
|
||||
}
|
||||
return nullptr;  // no producer recorded under this name
|
||||
}
|
||||
|
||||
// Shared lookup behind GetConsumerNodes() and GetMutableConsumerNodes(), both
// of which forward *this as `instance`.
// Searches instance.node_arg_to_consumer_nodes_ for `node_arg_name` and returns
// one instance.GetNode(index) result per stored index, in iteration order.
// Returns an empty vector when the name has no entry.
// The element type is taken from instance.GetNode(), so the same body serves
// the const wrapper and the non-const wrapper.
template <typename TInstance>
|
||||
static auto GetConsumerNodesImpl(
|
||||
TInstance& instance, const std::string& node_arg_name)
|
||||
-> std::vector<decltype(instance.GetNode(0))> {
|
||||
std::vector<decltype(instance.GetNode(0))> results;
|
||||
auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
|
||||
if (iter != instance.node_arg_to_consumer_nodes_.end()) {
|
||||
// One element is appended per stored index, so size the vector up front.
results.reserve(iter->second.size());
|
||||
for (auto node_index : iter->second) {
|
||||
results.push_back(instance.GetNode(node_index));
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
const Model& owning_model_;
|
||||
|
||||
// GraphProto to store name, version, initializer.
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class GraphViewer {
|
|||
bool IsSubgraph() const;
|
||||
|
||||
/** Get the internal graph*/
|
||||
const Graph* GetGraph() const { return graph_; }
|
||||
const Graph& GetGraph() const { return *graph_; }
|
||||
|
||||
/**
|
||||
returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
|
||||
|
|
|
|||
|
|
@ -462,7 +462,7 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue
|
|||
if (to_be_executed_nodes_.find(fetch_mlvalue_idxs) != to_be_executed_nodes_.end())
|
||||
return;
|
||||
|
||||
const Graph* graph = GetGraphViewer()->GetGraph();
|
||||
const Graph& graph = GetGraphViewer()->GetGraph();
|
||||
|
||||
// Get the nodes generating the fetches.
|
||||
std::vector<const Node*> nodes;
|
||||
|
|
@ -475,12 +475,12 @@ void SessionState::UpdateToBeExecutedNodes(const std::vector<int>& fetch_mlvalue
|
|||
to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes));
|
||||
return;
|
||||
}
|
||||
auto ending_node = graph->GetProducerNode(node_arg_name);
|
||||
auto ending_node = graph.GetProducerNode(node_arg_name);
|
||||
nodes.push_back(ending_node);
|
||||
}
|
||||
|
||||
// Reversely traverse to get reachable nodes.
|
||||
graph->ReverseDFSFrom(
|
||||
graph.ReverseDFSFrom(
|
||||
nodes, {}, [&reachable_nodes](const Node* n) { reachable_nodes.insert(n->Index()); });
|
||||
to_be_executed_nodes_.insert(std::make_pair(fetch_mlvalue_idxs, reachable_nodes));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2895,7 +2895,7 @@ Graph::~Graph() {
|
|||
std::ostream& operator<<(std::ostream& out, const Graph& graph) {
|
||||
out << "Inputs:\n";
|
||||
for (auto* x : graph.GetInputs()) {
|
||||
out << " " << x->Name() << " : " << *x->Type() << std::endl;
|
||||
out << " " << x->Name() << " : " << *x->Type() << "\n";
|
||||
}
|
||||
out << "Nodes:\n";
|
||||
for (auto& node : graph.Nodes()) {
|
||||
|
|
@ -2913,11 +2913,11 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) {
|
|||
}
|
||||
out << ", ";
|
||||
}
|
||||
out << std::endl;
|
||||
out << "\n";
|
||||
}
|
||||
out << "Outputs:\n";
|
||||
for (auto* x : graph.GetOutputs()) {
|
||||
out << " " << x->Name() << " : " << *x->Type() << std::endl;
|
||||
out << " " << x->Name() << " : " << *x->Type() << "\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue