From 5f62d4bb3bcfda0ef747fff655595c2247758a13 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 13 May 2021 10:34:28 -0700 Subject: [PATCH] Oops, scope DistributedRunContext to just NCCL --- onnxruntime/core/framework/provider_bridge_ort.cc | 4 +++- .../core/providers/shared_library/provider_interfaces.h | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index fff9daa65d..09f665eb8a 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -62,10 +62,10 @@ Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, co #include "orttraining/training_ops/cpu/controlflow/yield.h" #include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h" #include "orttraining/training_ops/cpu/tensor/split.h" -#include "orttraining/core/framework/distributed_run_context.h" #endif #if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) #include "orttraining/training_ops/cuda/communication/nccl_service.h" +#include "orttraining/core/framework/distributed_run_context.h" #endif namespace ONNX_NAMESPACE { @@ -881,9 +881,11 @@ struct ProviderHostImpl : ProviderHost { Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); } Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); } +#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) training::DistributedRunContext& GetDistributedRunContextInstance() override { return training::DistributedRunContext::GetInstance(); } #endif #endif +#endif } provider_host_; struct ProviderSharedLibrary { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 3644b4d918..1993ffc311 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -789,9 +789,11 @@ struct ProviderHost { virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) = 0; virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0; +#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P) virtual training::DistributedRunContext& GetDistributedRunContextInstance() = 0; #endif #endif +#endif }; #ifdef SHARED_PROVIDER