mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Oops, scope DistributedRunContext to just NCCL
This commit is contained in:
parent
374ae71739
commit
5f62d4bb3b
2 changed files with 5 additions and 1 deletions
|
|
@ -62,10 +62,10 @@ Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, co
|
|||
#include "orttraining/training_ops/cpu/controlflow/yield.h"
|
||||
#include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h"
|
||||
#include "orttraining/training_ops/cpu/tensor/split.h"
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#endif
|
||||
#if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
|
||||
#include "orttraining/training_ops/cuda/communication/nccl_service.h"
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
#endif
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
|
|
@ -881,9 +881,11 @@ struct ProviderHostImpl : ProviderHost {
|
|||
Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); }
|
||||
Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); }
|
||||
|
||||
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
|
||||
training::DistributedRunContext& GetDistributedRunContextInstance() override { return training::DistributedRunContext::GetInstance(); }
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
} provider_host_;
|
||||
|
||||
struct ProviderSharedLibrary {
|
||||
|
|
|
|||
|
|
@ -789,9 +789,11 @@ struct ProviderHost {
|
|||
virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) = 0;
|
||||
virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0;
|
||||
|
||||
#if defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
|
||||
virtual training::DistributedRunContext& GetDistributedRunContextInstance() = 0;
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef SHARED_PROVIDER
|
||||
|
|
|
|||
Loading…
Reference in a new issue