Oops, scope DistributedRunContext to just NCCL

This commit is contained in:
Ryan Hill 2021-05-13 10:34:28 -07:00
parent 374ae71739
commit 5f62d4bb3b
2 changed files with 5 additions and 1 deletions

View file

@ -62,10 +62,10 @@ Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, co
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h"
#include "orttraining/training_ops/cpu/tensor/split.h"
#include "orttraining/core/framework/distributed_run_context.h"
#endif
#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
#include "orttraining/training_ops/cuda/communication/nccl_service.h"
#include "orttraining/core/framework/distributed_run_context.h"
#endif
namespace ONNX_NAMESPACE {
@ -881,9 +881,11 @@ struct ProviderHostImpl : ProviderHost {
Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); }
Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); }
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
training::DistributedRunContext& GetDistributedRunContextInstance() override { return training::DistributedRunContext::GetInstance(); }
#endif
#endif
#endif
} provider_host_;
struct ProviderSharedLibrary {

View file

@ -789,9 +789,11 @@ struct ProviderHost {
virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) = 0;
virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0;
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
virtual training::DistributedRunContext& GetDistributedRunContextInstance() = 0;
#endif
#endif
#endif
};
#ifdef SHARED_PROVIDER