diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 89919d0dd1..803a95501d 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -40,6 +40,9 @@ endif() # Needed for the provider interface, as it includes training headers when training is enabled if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_framework PRIVATE ${ORTTRAINING_ROOT}) + if (onnxruntime_USE_NCCL OR onnxruntime_USE_MPI) + target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS}) + endif() endif() onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto protobuf::libprotobuf flatbuffers) set_target_properties(onnxruntime_framework PROPERTIES FOLDER "ONNXRuntime") diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index e462064477..554b3eea37 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -63,7 +63,7 @@ Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, co #include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h" #include "orttraining/training_ops/cpu/tensor/split.h" #endif -#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) +#if defined(USE_CUDA) && defined(ORT_USE_NCCL) #include "orttraining/training_ops/cuda/communication/nccl_service.h" #include "orttraining/core/framework/distributed_run_context.h" #endif @@ -883,7 +883,7 @@ struct ProviderHostImpl : ProviderHost { Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); } Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); } -#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) +#if defined(ORT_USE_NCCL) training::DistributedRunContext& GetDistributedRunContextInstance() override { return training::DistributedRunContext::GetInstance(); } #endif #endif diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 3c58150e9f..aa3304718f 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -40,6 +40,7 @@ struct DeleteOnUnloadPtr { #include #include #include +#include #include "onnx/common/stl_backports.h" #include "core/common/common.h" #include "core/common/const_pointer_container.h" diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index d2ad57e925..7261b5f1ea 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -791,7 +791,7 @@ struct ProviderHost { virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) = 0; virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0; -#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) +#if defined(ORT_USE_NCCL) virtual training::DistributedRunContext& GetDistributedRunContextInstance() = 0; #endif #endif