diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 8fae7ae8b0..08ff82cee0 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -36,7 +36,10 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id])); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault])); + // Copy only if the two addresses are different. + if (dst_data != src_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault])); + } } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); diff --git a/orttraining/orttraining/test/gradient/event_op_test.cc b/orttraining/orttraining/test/gradient/event_op_test.cc new file mode 100644 index 0000000000..881ed4fa5d --- /dev/null +++ b/orttraining/orttraining/test/gradient/event_op_test.cc @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "test/common/tensor_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/test_random_seed.h" +#include "test/util/include/default_providers.h" + +#include "onnx/defs/attr_proto_util.h" + +namespace onnxruntime { +namespace test { + +// Run GPU op for GPU build. Otherwise, run GPU op. +void run_provider_specific_optest(OpTester& tester) { + RunOptions run_option; +#ifdef USE_CUDA + std::vector> providers; + providers.push_back(DefaultCudaExecutionProvider()); +#else + std::vector> providers; + providers.push_back(DefaultCpuExecutionProvider()); +#endif + tester.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + std::unordered_set(), + &run_option, + &providers); +} + +void record_event(int64_t event_id) { + OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); + test_record.AddInput("EventIdentifier", {}, {event_id}); + test_record.AddInput("InputSignal", {}, {true}); + test_record.AddOutput("OutputSignal", {}, {true}); + run_provider_specific_optest(test_record); +} + +void record_event_multiple_inputs_and_outputs(int64_t event_id) { + OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); + test_record.AddInput("EventIdentifier", {}, {event_id}); + test_record.AddInput("InputSignal", {}, {true}); + test_record.AddInput("Input1", {3}, {9.4f, 1.7f, 3.6f}); + test_record.AddInput("Input2", {1}, {1.6f}); + test_record.AddOutput("OutputSignal", {}, {true}); + test_record.AddOutput("Output1", {3}, {9.4f, 1.7f, 3.6f}); + test_record.AddOutput("Output2", {1}, {1.6f}); + run_provider_specific_optest(test_record); +} + +void wait_event(int64_t event_id) { + OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); + test_wait.AddInput("EventIdentifier", {}, {event_id}); + test_wait.AddInput("InputSignal", {}, {true}); + test_wait.AddOutput("OutputSignal", {}, {true}); + run_provider_specific_optest(test_wait); +} + +void wait_event_multiple_inputs_and_outputs(int64_t event_id) { + OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); + test_wait.AddInput("EventIdentifier", {}, {event_id}); + test_wait.AddInput("InputSignal", {}, {true}); + test_wait.AddInput("Input1", {1}, {1.6f}); + test_wait.AddInput("Input2", {3}, {9.4f, 1.7f, 3.6f}); + test_wait.AddOutput("OutputSignal", {}, {true}); + test_wait.AddOutput("output1", {1}, {1.6f}); + test_wait.AddOutput("output2", {3}, {9.4f, 1.7f, 3.6f}); + run_provider_specific_optest(test_wait); +} + +TEST(Synchronization, RecordAndWaitEvent) { + const int64_t event_id = static_cast(1736); + record_event(event_id); + wait_event(event_id); +} + +TEST(Synchronization, WaitNullEvent) { + wait_event(-1); +} + +TEST(Synchronization, RecordAndWaitEventMultipleInputsAndOutputs) { + const int64_t event_id = static_cast(995); + record_event_multiple_inputs_and_outputs(event_id); + wait_event_multiple_inputs_and_outputs(event_id); +} + +TEST(Synchronization, WaitAndRecordEvent) { + const int64_t event_id = static_cast(1228); + std::thread waiting_thread(wait_event, event_id); + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::thread recording_thread(record_event, event_id); + + waiting_thread.join(); + recording_thread.join(); +} + +TEST(Synchronization, WaitAndRecordEventMany) { + const size_t event_count = 16; + for (int i = 0; i < 8; ++i) { + std::thread thread_pool[2 * event_count]; + for (int j = 0; j < static_cast(event_count); ++j) { + thread_pool[j] = std::thread(wait_event, j); + thread_pool[j + event_count] = std::thread(record_event, j); + } + for (size_t j = 0; j < event_count; ++j) { + thread_pool[j].join(); + thread_pool[j + event_count].join(); + } + } +} + +} // namespace test +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 1b997547aa..d3201864d8 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -1798,53 +1798,6 @@ TEST(GradientCheckerTest, SliceGrad) { } } -void record_event(int64_t event_id) { - OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain); - test_record.AddInput("EventIdentifier", {}, {event_id}); - test_record.AddInput("InputSignal", {}, {true}); - test_record.AddOutput("OutputSignal", {}, {true}); - test_record.Run(); -} - -void wait_event(int64_t event_id) { - OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain); - test_wait.AddInput("EventIdentifier", {}, {event_id}); - test_wait.AddInput("InputSignal", {}, {true}); - test_wait.AddOutput("OutputSignal", {}, {true}); - test_wait.Run(); -} - -TEST(Synchronization, RecordAndWaitEvent) { - const int64_t event_id = static_cast(1736); - record_event(event_id); - wait_event(event_id); -} - -TEST(Synchronization, WaitAndRecordEvent) { - const int64_t event_id = static_cast(1228); - std::thread waiting_thread(wait_event, event_id); - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - std::thread recording_thread(record_event, event_id); - - waiting_thread.join(); - recording_thread.join(); -} - -TEST(Synchronization, WaitAndRecordEventMany) { - const size_t event_count = 16; - for (int i = 0; i < 8; ++i) { - std::thread thread_pool[2 * event_count]; - for (int j = 0; j < static_cast(event_count); ++j) { - thread_pool[j] = std::thread(wait_event, j); - thread_pool[j + event_count] = std::thread(record_event, j); - } - for (size_t j = 0; j < event_count; ++j) { - thread_pool[j].join(); - thread_pool[j + event_count].join(); - } - } -} - TEST(GradientCheckerTest, ExpandGrad) { float max_error; GradientChecker gradient_checker; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/common.h b/orttraining/orttraining/training_ops/cpu/controlflow/common.h new file mode 100644 index 0000000000..c6ff14689f --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/controlflow/common.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace contrib { + +template +std::vector> AliasRange(int start, int end) { + std::vector> aliases; + for (int i = start; i < end; i++) { + aliases.push_back(std::pair(input_start + i, output_start + i)); + } + return aliases; +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc index a39ab399c2..e9bb0c51c3 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.cc @@ -6,8 +6,16 @@ namespace onnxruntime { namespace contrib { +void OrtEventPool::CheckRange(const int64_t id) const { + ORT_ENFORCE( + id >= 0 && id < MaxNumItems, + "Got id ", id, + ". It should be in a range from 0 to ", + MaxNumItems, "."); +} + void OrtEventPool::SignalEvent(int64_t id) { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); std::unique_lock lock(pool_[id].mutex); pool_[id].signaled.store(true); lock.unlock(); @@ -15,18 +23,18 @@ void OrtEventPool::SignalEvent(int64_t id) { }; void OrtEventPool::ResetEvent(int64_t id) { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); std::lock_guard guard(pool_[id].mutex); pool_[id].signaled.store(false); }; bool OrtEventPool::QueryEvent(int64_t id) const { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); return pool_[id].signaled.load(); } void OrtEventPool::WaitEvent(int64_t id) const { - ORT_ENFORCE(id >= 0 && id < MaxNumItems); + CheckRange(id); std::unique_lock lock(pool_[id].mutex); pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); }); }; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h index 511a5da234..68e9b95abb 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/event_pool.h @@ -34,6 +34,8 @@ class OrtEventPool final { OrtEventPool(const OrtEventPool&) = delete; OrtEventPool& operator=(const OrtEventPool&) = delete; + void CheckRange(const int64_t event_id) const; + struct Item { std::atomic signaled; mutable std::mutex mutex; @@ -43,9 +45,11 @@ class OrtEventPool final { signaled.store(false); } }; + enum { MaxNumItems = 4096 }; + Item pool_[MaxNumItems]; }; diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/record.cc b/orttraining/orttraining/training_ops/cpu/controlflow/record.cc index c36a23e9e7..eb305a7c27 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/record.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/record.cc @@ -1,19 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "record.h" +#include "orttraining/training_ops/cpu/controlflow/record.h" #include "core/providers/cpu/tensor/utils.h" +#include "common.h" namespace onnxruntime { namespace contrib { -template -std::vector> AliasRange(int start, int end) { - std::vector> aliases; - for (int i = start; i < end; i++) { - aliases.push_back(std::pair(input_start + i, output_start + i)); - } - return aliases; +void record_event_in_tensor(const Tensor& event_id_tensor) { + const int64_t event_id = *event_id_tensor.template Data(); + ORT_ENFORCE(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent"); + OrtEventPool::GetInstance().SignalEvent(event_id); } ONNX_OPERATOR_KERNEL_EX( @@ -28,12 +26,8 @@ ONNX_OPERATOR_KERNEL_EX( RecordEvent); Status RecordEvent::Compute(OpKernelContext* ctx) const { - const Tensor* event_id_tensor = ctx->Input(0); - const int64_t event_id = *event_id_tensor->template Data(); - - ORT_RETURN_IF_NOT(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent"); - - OrtEventPool::GetInstance().SignalEvent(event_id); + // Pass event-id tensor to event-recording helper function. + record_event_in_tensor(*ctx->Input(0)); for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { const Tensor* X = ctx->Input(i_out + 1); diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/record.h b/orttraining/orttraining/training_ops/cpu/controlflow/record.h index d4f02a612d..61fb28abba 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/record.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/record.h @@ -9,6 +9,9 @@ namespace onnxruntime { namespace contrib { +// Record the event ID stored in the input tensor. +void record_event_in_tensor(const Tensor& event_id_tensor); + class RecordEvent final : public OpKernel { public: RecordEvent(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc b/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc index b08ba65ddd..67f3128234 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/wait.cc @@ -1,19 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "wait.h" +#include "orttraining/training_ops/cpu/controlflow/wait.h" #include "core/providers/cpu/tensor/utils.h" +#include "common.h" namespace onnxruntime { namespace contrib { -template -std::vector> AliasRange(int start, int end) { - std::vector> aliases; - for (int i = start; i < end; i++) { - aliases.push_back(std::pair(input_start + i, output_start + i)); +void wait_event_in_tensor(const Tensor& event_id_tensor) { + const int64_t event_id = *event_id_tensor.template Data(); + // -1 is reserved to skip wait event + if (event_id != -1) { + // Wait the event to be recorded by a RecordEvent operator. + OrtEventPool::GetInstance().WaitEvent(event_id); + // BUGBUG: seems this would cause hang when a event is being waited more than once + // Destory the recorded event. + OrtEventPool::GetInstance().ResetEvent(event_id); } - return aliases; } ONNX_OPERATOR_KERNEL_EX( @@ -28,18 +32,7 @@ ONNX_OPERATOR_KERNEL_EX( WaitEvent); Status WaitEvent::Compute(OpKernelContext* ctx) const { - const Tensor* event_id_tensor = ctx->Input(0); - const int64_t event_id = *event_id_tensor->template Data(); - - // -1 is reserved to skip wait event - if (event_id != -1) { - // Wait the event to be recorded by a RecordEvent operator. - OrtEventPool::GetInstance().WaitEvent(event_id); - - // BUGBUG: seems this would cause hang when a event is being waited more than once - // Destory the recorded event. - OrtEventPool::GetInstance().ResetEvent(event_id); - } + wait_event_in_tensor(*ctx->Input(0)); for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { const Tensor* X = ctx->Input(i_out + 1); diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/wait.h b/orttraining/orttraining/training_ops/cpu/controlflow/wait.h index dff514880f..682aa4388e 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/wait.h +++ b/orttraining/orttraining/training_ops/cpu/controlflow/wait.h @@ -2,8 +2,6 @@ // Licensed under the MIT License. #pragma once -#include -#include #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "event_pool.h" @@ -11,6 +9,9 @@ namespace onnxruntime { namespace contrib { +// Wait for the event ID stored in the input tensor. +void wait_event_in_tensor(const Tensor& event_id_tensor); + class WaitEvent final : public OpKernel { public: WaitEvent(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/record.cc b/orttraining/orttraining/training_ops/cuda/controlflow/record.cc new file mode 100644 index 0000000000..2f8bb82a2e --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/record.cc @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/controlflow/record.h" +#include "core/providers/cpu/tensor/utils.h" +// Include RecordEvent's utility functions shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/common.h" +// Include event mechanism shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/event_pool.h" +#include "orttraining/training_ops/cpu/controlflow/record.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + RecordEvent, + kMSDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) /* Keep EventIdentifier in CPU */ + .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)), + RecordEvent); + +Status RecordEvent::ComputeInternal(OpKernelContext* ctx) const { + // Reuse CPU helper to record event because event tensor is a CPU tensor. + onnxruntime::contrib::record_event_in_tensor(*ctx->Input(0)); + + for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { + // This iteration copies (i-1)-th input to i-th output. + const Tensor* X = ctx->Input(i_out + 1); + const TensorShape& data_shape = X->Shape(); + Tensor* Y = ctx->Output(i_out, data_shape); + CopyTensor(*X, *Y); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/record.h b/orttraining/orttraining/training_ops/cuda/controlflow/record.h new file mode 100644 index 0000000000..0063af48f1 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/record.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" + +namespace onnxruntime { +namespace cuda { + +class RecordEvent final : public CudaKernel { +public: + RecordEvent(const OpKernelInfo& info) : CudaKernel(info) { } + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc b/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc new file mode 100644 index 0000000000..f90177c51a --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/wait.cc @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/controlflow/wait.h" +#include "core/providers/cpu/tensor/utils.h" +// Include RecordEvent's utility functions shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/common.h" +// Include event mechanism shared by CPU and GPU implementations. +#include "orttraining/training_ops/cpu/controlflow/event_pool.h" +#include "orttraining/training_ops/cpu/controlflow/wait.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + WaitEvent, + kMSDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) /* CPU variable */ + .TypeConstraint("TInt64", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)), + WaitEvent); + +Status WaitEvent::ComputeInternal(OpKernelContext* ctx) const { + // Reuse CPU helper to wait event because event tensor is a CPU tensor. + onnxruntime::contrib::wait_event_in_tensor(*ctx->Input(0)); + + for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { + // This iteration copies (i-1)-th input to i-th output. + const Tensor* X = ctx->Input(i_out + 1); + const TensorShape& data_shape = X->Shape(); + Tensor* Y = ctx->Output(i_out, data_shape); + CopyTensor(*X, *Y); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/controlflow/wait.h b/orttraining/orttraining/training_ops/cuda/controlflow/wait.h new file mode 100644 index 0000000000..b4a687fef6 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/controlflow/wait.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" + +namespace onnxruntime { +namespace cuda { + +class WaitEvent final : public CudaKernel { +public: + WaitEvent(const OpKernelInfo& info) : CudaKernel(info) { } + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime \ No newline at end of file diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index bcc5c4fa91..537864f1a1 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -131,6 +131,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv); #endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent); + #ifdef USE_NCCL class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather); @@ -263,6 +266,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif + BuildKernelCreateInfo, + BuildKernelCreateInfo, + #ifdef USE_NCCL BuildKernelCreateInfo, BuildKernelCreateInfo,