mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Support GPU Event Operators (#3653)
* Add GPU event operators to support in-place updates in gradient accumulator and optimizer for modifying the tensors passing through those event operators. * Address comment and polish code * Merge shared code between CPU and GPU kernels * Move event test to a new file * Address comments * Update onnxruntime/core/providers/cuda/gpu_data_transfer.cc
This commit is contained in:
parent
d06763ac1c
commit
72b38f0a8b
15 changed files with 317 additions and 87 deletions
|
|
@ -36,7 +36,10 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e
|
|||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id]));
|
||||
} else if (src_device.Type() == OrtDevice::GPU) {
|
||||
// copying between GPU, this is non-blocking
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault]));
|
||||
// Copy only if the two addresses are different.
|
||||
if (dst_data != src_data) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault]));
|
||||
}
|
||||
} else {
|
||||
// copy from other CPU memory to GPU, this is blocking
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));
|
||||
|
|
|
|||
122
orttraining/orttraining/test/gradient/event_op_test.cc
Normal file
122
orttraining/orttraining/test/gradient/event_op_test.cc
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <bitset>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/util/include/test_random_seed.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
||||
#include "onnx/defs/attr_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Run GPU op for GPU build. Otherwise, run GPU op.
|
||||
void run_provider_specific_optest(OpTester& tester) {
|
||||
RunOptions run_option;
|
||||
#ifdef USE_CUDA
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> providers;
|
||||
providers.push_back(DefaultCudaExecutionProvider());
|
||||
#else
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> providers;
|
||||
providers.push_back(DefaultCpuExecutionProvider());
|
||||
#endif
|
||||
tester.Run(
|
||||
OpTester::ExpectResult::kExpectSuccess,
|
||||
"",
|
||||
std::unordered_set<std::string>(),
|
||||
&run_option,
|
||||
&providers);
|
||||
}
|
||||
|
||||
void record_event(int64_t event_id) {
|
||||
OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain);
|
||||
test_record.AddInput<int64_t>("EventIdentifier", {}, {event_id});
|
||||
test_record.AddInput<bool>("InputSignal", {}, {true});
|
||||
test_record.AddOutput<bool>("OutputSignal", {}, {true});
|
||||
run_provider_specific_optest(test_record);
|
||||
}
|
||||
|
||||
void record_event_multiple_inputs_and_outputs(int64_t event_id) {
|
||||
OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain);
|
||||
test_record.AddInput<int64_t>("EventIdentifier", {}, {event_id});
|
||||
test_record.AddInput<bool>("InputSignal", {}, {true});
|
||||
test_record.AddInput<float>("Input1", {3}, {9.4f, 1.7f, 3.6f});
|
||||
test_record.AddInput<float>("Input2", {1}, {1.6f});
|
||||
test_record.AddOutput<bool>("OutputSignal", {}, {true});
|
||||
test_record.AddOutput<float>("Output1", {3}, {9.4f, 1.7f, 3.6f});
|
||||
test_record.AddOutput<float>("Output2", {1}, {1.6f});
|
||||
run_provider_specific_optest(test_record);
|
||||
}
|
||||
|
||||
void wait_event(int64_t event_id) {
|
||||
OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain);
|
||||
test_wait.AddInput<int64_t>("EventIdentifier", {}, {event_id});
|
||||
test_wait.AddInput<bool>("InputSignal", {}, {true});
|
||||
test_wait.AddOutput<bool>("OutputSignal", {}, {true});
|
||||
run_provider_specific_optest(test_wait);
|
||||
}
|
||||
|
||||
void wait_event_multiple_inputs_and_outputs(int64_t event_id) {
|
||||
OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain);
|
||||
test_wait.AddInput<int64_t>("EventIdentifier", {}, {event_id});
|
||||
test_wait.AddInput<bool>("InputSignal", {}, {true});
|
||||
test_wait.AddInput<float>("Input1", {1}, {1.6f});
|
||||
test_wait.AddInput<float>("Input2", {3}, {9.4f, 1.7f, 3.6f});
|
||||
test_wait.AddOutput<bool>("OutputSignal", {}, {true});
|
||||
test_wait.AddOutput<float>("output1", {1}, {1.6f});
|
||||
test_wait.AddOutput<float>("output2", {3}, {9.4f, 1.7f, 3.6f});
|
||||
run_provider_specific_optest(test_wait);
|
||||
}
|
||||
|
||||
// Records an event first, then waits on it from the same thread.
TEST(Synchronization, RecordAndWaitEvent) {
  constexpr int64_t kEventId = 1736;
  record_event(kEventId);
  wait_event(kEventId);
}
|
||||
|
||||
// Waits on event ID -1 without any prior RecordEvent.
TEST(Synchronization, WaitNullEvent) {
  constexpr int64_t kNullEventId = -1;
  wait_event(kNullEventId);
}
|
||||
|
||||
// Record-then-wait on one event, with extra float tensors passed through both ops.
TEST(Synchronization, RecordAndWaitEventMultipleInputsAndOutputs) {
  constexpr int64_t kEventId = 995;
  record_event_multiple_inputs_and_outputs(kEventId);
  wait_event_multiple_inputs_and_outputs(kEventId);
}
|
||||
|
||||
// Starts the waiter first, then (5 ms later) a second thread that records the
// same event, and joins both.
TEST(Synchronization, WaitAndRecordEvent) {
  constexpr int64_t kEventId = 1228;

  std::thread waiter(wait_event, kEventId);
  // Give the waiter a head start before the event gets recorded.
  std::this_thread::sleep_for(std::chrono::milliseconds(5));
  std::thread recorder(record_event, kEventId);

  waiter.join();
  recorder.join();
}
|
||||
|
||||
// Eight rounds; each round starts one waiter and one recorder thread for every
// event ID in [0, 16) and joins them all before the next round begins.
TEST(Synchronization, WaitAndRecordEventMany) {
  constexpr int kRounds = 8;
  constexpr int kEventCount = 16;

  for (int round = 0; round < kRounds; ++round) {
    std::thread waiters[kEventCount];
    std::thread recorders[kEventCount];

    // Launch order per event: waiter first, then its recorder.
    for (int event_id = 0; event_id < kEventCount; ++event_id) {
      waiters[event_id] = std::thread(wait_event, event_id);
      recorders[event_id] = std::thread(record_event, event_id);
    }

    // Join order per event: waiter first, then its recorder.
    for (int event_id = 0; event_id < kEventCount; ++event_id) {
      waiters[event_id].join();
      recorders[event_id].join();
    }
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1798,53 +1798,6 @@ TEST(GradientCheckerTest, SliceGrad) {
|
|||
}
|
||||
}
|
||||
|
||||
void record_event(int64_t event_id) {
|
||||
OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain);
|
||||
test_record.AddInput<int64_t>("EventIdentifier", {}, {event_id});
|
||||
test_record.AddInput<bool>("InputSignal", {}, {true});
|
||||
test_record.AddOutput<bool>("OutputSignal", {}, {true});
|
||||
test_record.Run();
|
||||
}
|
||||
|
||||
void wait_event(int64_t event_id) {
|
||||
OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain);
|
||||
test_wait.AddInput<int64_t>("EventIdentifier", {}, {event_id});
|
||||
test_wait.AddInput<bool>("InputSignal", {}, {true});
|
||||
test_wait.AddOutput<bool>("OutputSignal", {}, {true});
|
||||
test_wait.Run();
|
||||
}
|
||||
|
||||
TEST(Synchronization, RecordAndWaitEvent) {
|
||||
const int64_t event_id = static_cast<int64_t>(1736);
|
||||
record_event(event_id);
|
||||
wait_event(event_id);
|
||||
}
|
||||
|
||||
TEST(Synchronization, WaitAndRecordEvent) {
|
||||
const int64_t event_id = static_cast<int64_t>(1228);
|
||||
std::thread waiting_thread(wait_event, event_id);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
std::thread recording_thread(record_event, event_id);
|
||||
|
||||
waiting_thread.join();
|
||||
recording_thread.join();
|
||||
}
|
||||
|
||||
TEST(Synchronization, WaitAndRecordEventMany) {
|
||||
const size_t event_count = 16;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
std::thread thread_pool[2 * event_count];
|
||||
for (int j = 0; j < static_cast<int>(event_count); ++j) {
|
||||
thread_pool[j] = std::thread(wait_event, j);
|
||||
thread_pool[j + event_count] = std::thread(record_event, j);
|
||||
}
|
||||
for (size_t j = 0; j < event_count; ++j) {
|
||||
thread_pool[j].join();
|
||||
thread_pool[j + event_count].join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, ExpandGrad) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Builds the list of (input index, output index) pairs
//   (input_start + i, output_start + i)   for every i in [start, end).
// Returns an empty vector when end <= start.
// NOTE: this header includes nothing itself; std::vector and std::pair have to
// be made available by whoever includes it.
template <int input_start, int output_start>
std::vector<std::pair<int, int>> AliasRange(int start, int end) {
  std::vector<std::pair<int, int>> aliases;
  if (end > start) {
    // The final size is known, so allocate once instead of growing repeatedly.
    // The difference is taken in a wider type to avoid int overflow.
    aliases.reserve(static_cast<std::size_t>(static_cast<long long>(end) - start));
  }
  for (int i = start; i < end; ++i) {
    aliases.emplace_back(input_start + i, output_start + i);
  }
  return aliases;
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -6,8 +6,16 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Verifies that `id` is a valid index into the event pool, i.e.
// 0 <= id < MaxNumItems. On failure ORT_ENFORCE reports the offending id
// together with the inclusive range of valid ids.
void OrtEventPool::CheckRange(const int64_t id) const {
  ORT_ENFORCE(
      id >= 0 && id < MaxNumItems,
      "Got id ", id,
      ". It should be in a range from 0 to ",
      // The check is exclusive at the top, so the largest valid id is MaxNumItems - 1.
      MaxNumItems - 1, ".");
}
|
||||
|
||||
void OrtEventPool::SignalEvent(int64_t id) {
|
||||
ORT_ENFORCE(id >= 0 && id < MaxNumItems);
|
||||
CheckRange(id);
|
||||
std::unique_lock<std::mutex> lock(pool_[id].mutex);
|
||||
pool_[id].signaled.store(true);
|
||||
lock.unlock();
|
||||
|
|
@ -15,18 +23,18 @@ void OrtEventPool::SignalEvent(int64_t id) {
|
|||
};
|
||||
|
||||
// Clears the signaled flag of event `id` while holding that event's mutex.
// `id` is validated once, through CheckRange.
void OrtEventPool::ResetEvent(int64_t id) {
  CheckRange(id);
  std::lock_guard<std::mutex> guard(pool_[id].mutex);
  pool_[id].signaled.store(false);
}
|
||||
|
||||
bool OrtEventPool::QueryEvent(int64_t id) const {
|
||||
ORT_ENFORCE(id >= 0 && id < MaxNumItems);
|
||||
CheckRange(id);
|
||||
return pool_[id].signaled.load();
|
||||
}
|
||||
|
||||
// Blocks the caller on the event's condition variable until the signaled flag
// of event `id` reads true. `id` is validated once, through CheckRange.
void OrtEventPool::WaitEvent(int64_t id) const {
  CheckRange(id);
  std::unique_lock<std::mutex> lock(pool_[id].mutex);
  // The predicate form re-checks the flag after every wake-up.
  pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); });
}
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ class OrtEventPool final {
|
|||
OrtEventPool(const OrtEventPool&) = delete;
|
||||
OrtEventPool& operator=(const OrtEventPool&) = delete;
|
||||
|
||||
void CheckRange(const int64_t event_id) const;
|
||||
|
||||
struct Item {
|
||||
std::atomic<bool> signaled;
|
||||
mutable std::mutex mutex;
|
||||
|
|
@ -43,9 +45,11 @@ class OrtEventPool final {
|
|||
signaled.store(false);
|
||||
}
|
||||
};
|
||||
|
||||
enum {
|
||||
MaxNumItems = 4096
|
||||
};
|
||||
|
||||
Item pool_[MaxNumItems];
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "record.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/record.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <int input_start, int output_start>
|
||||
std::vector<std::pair<int, int>> AliasRange(int start, int end) {
|
||||
std::vector<std::pair<int, int>> aliases;
|
||||
for (int i = start; i < end; i++) {
|
||||
aliases.push_back(std::pair<int, int>(input_start + i, output_start + i));
|
||||
}
|
||||
return aliases;
|
||||
void record_event_in_tensor(const Tensor& event_id_tensor) {
|
||||
const int64_t event_id = *event_id_tensor.template Data<int64_t>();
|
||||
ORT_ENFORCE(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent");
|
||||
OrtEventPool::GetInstance().SignalEvent(event_id);
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -28,12 +26,8 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
RecordEvent);
|
||||
|
||||
Status RecordEvent::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* event_id_tensor = ctx->Input<Tensor>(0);
|
||||
const int64_t event_id = *event_id_tensor->template Data<int64_t>();
|
||||
|
||||
ORT_RETURN_IF_NOT(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent");
|
||||
|
||||
OrtEventPool::GetInstance().SignalEvent(event_id);
|
||||
// Pass event-id tensor to event-recording helper function.
|
||||
record_event_in_tensor(*ctx->Input<Tensor>(0));
|
||||
|
||||
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
|
||||
const Tensor* X = ctx->Input<Tensor>(i_out + 1);
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Record the event ID stored in the input tensor.
|
||||
void record_event_in_tensor(const Tensor& event_id_tensor);
|
||||
|
||||
class RecordEvent final : public OpKernel {
|
||||
public:
|
||||
RecordEvent(const OpKernelInfo& info) : OpKernel(info) { }
|
||||
|
|
|
|||
|
|
@ -1,19 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "wait.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/wait.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <int input_start, int output_start>
|
||||
std::vector<std::pair<int, int>> AliasRange(int start, int end) {
|
||||
std::vector<std::pair<int, int>> aliases;
|
||||
for (int i = start; i < end; i++) {
|
||||
aliases.push_back(std::pair<int, int>(input_start + i, output_start + i));
|
||||
void wait_event_in_tensor(const Tensor& event_id_tensor) {
|
||||
const int64_t event_id = *event_id_tensor.template Data<int64_t>();
|
||||
// -1 is reserved to skip wait event
|
||||
if (event_id != -1) {
|
||||
// Wait the event to be recorded by a RecordEvent operator.
|
||||
OrtEventPool::GetInstance().WaitEvent(event_id);
|
||||
// BUGBUG: seems this would cause hang when a event is being waited more than once
|
||||
// Destory the recorded event.
|
||||
OrtEventPool::GetInstance().ResetEvent(event_id);
|
||||
}
|
||||
return aliases;
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -28,18 +32,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
WaitEvent);
|
||||
|
||||
Status WaitEvent::Compute(OpKernelContext* ctx) const {
|
||||
const Tensor* event_id_tensor = ctx->Input<Tensor>(0);
|
||||
const int64_t event_id = *event_id_tensor->template Data<int64_t>();
|
||||
|
||||
// -1 is reserved to skip wait event
|
||||
if (event_id != -1) {
|
||||
// Wait the event to be recorded by a RecordEvent operator.
|
||||
OrtEventPool::GetInstance().WaitEvent(event_id);
|
||||
|
||||
// BUGBUG: this seems to cause a hang when an event is waited on more than once
|
||||
// Destroy the recorded event.
|
||||
OrtEventPool::GetInstance().ResetEvent(event_id);
|
||||
}
|
||||
wait_event_in_tensor(*ctx->Input<Tensor>(0));
|
||||
|
||||
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
|
||||
const Tensor* X = ctx->Input<Tensor>(i_out + 1);
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "event_pool.h"
|
||||
|
|
@ -11,6 +9,9 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Wait for the event ID stored in the input tensor.
|
||||
void wait_event_in_tensor(const Tensor& event_id_tensor);
|
||||
|
||||
class WaitEvent final : public OpKernel {
|
||||
public:
|
||||
WaitEvent(const OpKernelInfo& info) : OpKernel(info) { }
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/controlflow/record.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
// Include RecordEvent's utility functions shared by CPU and GPU implementations.
|
||||
#include "orttraining/training_ops/cpu/controlflow/common.h"
|
||||
// Include event mechanism shared by CPU and GPU implementations.
|
||||
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/record.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Registers the CUDA RecordEvent kernel: op RecordEvent, domain kMSDomain,
// version 1, provider kCudaExecutionProvider.
// Input 0 (the event identifier) is requested as a CPU input.
// AliasRange<1, 0>(0, 1024) yields the pairs (i + 1, i) for i in [0, 1024),
// i.e. input i + 1 is paired with output i.
// NOTE(review): presumably Alias lets each output share its paired input's
// buffer — confirm against KernelDefBuilder.
ONNX_OPERATOR_KERNEL_EX(
|
||||
RecordEvent,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep EventIdentifier in CPU */
|
||||
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)),
|
||||
RecordEvent);
|
||||
|
||||
// Records the event whose ID is stored in input 0, then forwards every
// remaining input to the output one position earlier.
// Returns the first copy failure, or Status::OK() when all copies succeed.
Status RecordEvent::ComputeInternal(OpKernelContext* ctx) const {
  // Reuse CPU helper to record event because event tensor is a CPU tensor.
  onnxruntime::contrib::record_event_in_tensor(*ctx->Input<Tensor>(0));

  for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
    // Input 0 is the event ID, so output i_out is fed by input i_out + 1.
    const Tensor* X = ctx->Input<Tensor>(i_out + 1);
    const TensorShape& data_shape = X->Shape();
    Tensor* Y = ctx->Output(i_out, data_shape);
    // Propagate a failed copy instead of dropping it.
    // NOTE(review): assumes CudaKernel::CopyTensor returns Status — confirm.
    ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y));
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class RecordEvent final : public CudaKernel {
|
||||
public:
|
||||
RecordEvent(const OpKernelInfo& info) : CudaKernel(info) { }
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/controlflow/wait.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
// Include WaitEvent's utility functions shared by CPU and GPU implementations.
|
||||
#include "orttraining/training_ops/cpu/controlflow/common.h"
|
||||
// Include event mechanism shared by CPU and GPU implementations.
|
||||
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/wait.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// Registers the CUDA WaitEvent kernel: op WaitEvent, domain kMSDomain,
// version 1, provider kCudaExecutionProvider.
// Input 0 (the event identifier) is requested as a CPU input.
// AliasRange<1, 0>(0, 1024) yields the pairs (i + 1, i) for i in [0, 1024),
// i.e. input i + 1 is paired with output i.
// NOTE(review): presumably Alias lets each output share its paired input's
// buffer — confirm against KernelDefBuilder.
ONNX_OPERATOR_KERNEL_EX(
|
||||
WaitEvent,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep EventIdentifier in CPU */
|
||||
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)),
|
||||
WaitEvent);
|
||||
|
||||
// Waits for the event whose ID is stored in input 0, then forwards every
// remaining input to the output one position earlier.
// Returns the first copy failure, or Status::OK() when all copies succeed.
Status WaitEvent::ComputeInternal(OpKernelContext* ctx) const {
  // Reuse CPU helper to wait event because event tensor is a CPU tensor.
  onnxruntime::contrib::wait_event_in_tensor(*ctx->Input<Tensor>(0));

  for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
    // Input 0 is the event ID, so output i_out is fed by input i_out + 1.
    const Tensor* X = ctx->Input<Tensor>(i_out + 1);
    const TensorShape& data_shape = X->Shape();
    Tensor* Y = ctx->Output(i_out, data_shape);
    // Propagate a failed copy instead of dropping it.
    // NOTE(review): assumes CudaKernel::CopyTensor returns Status — confirm.
    ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y));
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
19
orttraining/orttraining/training_ops/cuda/controlflow/wait.h
Normal file
19
orttraining/orttraining/training_ops/cuda/controlflow/wait.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class WaitEvent final : public CudaKernel {
|
||||
public:
|
||||
WaitEvent(const OpKernelInfo& info) : CudaKernel(info) { }
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -131,6 +131,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv);
|
||||
#endif
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent);
|
||||
|
||||
#ifdef USE_NCCL
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather);
|
||||
|
|
@ -263,6 +266,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv)>,
|
||||
#endif
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent)>,
|
||||
|
||||
#ifdef USE_NCCL
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue