Support GPU Event Operators (#3653)

* Add GPU event operators to support in-place updates in
gradient accumulator and optimizer for modifying the tensors
passing through those event operators.

* Address comment and polish code

* Merge shared code between CPU and GPU kernels

* Move event test to a new file

* Address comments

* Update onnxruntime/core/providers/cuda/gpu_data_transfer.cc
This commit is contained in:
Wei-Sheng Chin 2020-04-24 17:43:04 -07:00 committed by GitHub
parent d06763ac1c
commit 72b38f0a8b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 317 additions and 87 deletions

View file

@ -36,7 +36,10 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, streams_[exec_queue_id]));
} else if (src_device.Type() == OrtDevice::GPU) {
// copying between GPU, this is non-blocking
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault]));
// Copy only if the two addresses are different.
if (dst_data != src_data) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, streams_[kCudaStreamDefault]));
}
} else {
// copy from other CPU memory to GPU, this is blocking
CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice));

View file

@ -0,0 +1,122 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <algorithm>
#include <bitset>
#include <cmath>
#include <random>
#include <thread>
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/test_random_seed.h"
#include "test/util/include/default_providers.h"
#include "onnx/defs/attr_proto_util.h"
namespace onnxruntime {
namespace test {
// Run GPU op for GPU build. Otherwise, run GPU op.
void run_provider_specific_optest(OpTester& tester) {
RunOptions run_option;
#ifdef USE_CUDA
std::vector<std::unique_ptr<IExecutionProvider>> providers;
providers.push_back(DefaultCudaExecutionProvider());
#else
std::vector<std::unique_ptr<IExecutionProvider>> providers;
providers.push_back(DefaultCpuExecutionProvider());
#endif
tester.Run(
OpTester::ExpectResult::kExpectSuccess,
"",
std::unordered_set<std::string>(),
&run_option,
&providers);
}
void record_event(int64_t event_id) {
OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain);
test_record.AddInput<int64_t>("EventIdentifier", {}, {event_id});
test_record.AddInput<bool>("InputSignal", {}, {true});
test_record.AddOutput<bool>("OutputSignal", {}, {true});
run_provider_specific_optest(test_record);
}
void record_event_multiple_inputs_and_outputs(int64_t event_id) {
OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain);
test_record.AddInput<int64_t>("EventIdentifier", {}, {event_id});
test_record.AddInput<bool>("InputSignal", {}, {true});
test_record.AddInput<float>("Input1", {3}, {9.4f, 1.7f, 3.6f});
test_record.AddInput<float>("Input2", {1}, {1.6f});
test_record.AddOutput<bool>("OutputSignal", {}, {true});
test_record.AddOutput<float>("Output1", {3}, {9.4f, 1.7f, 3.6f});
test_record.AddOutput<float>("Output2", {1}, {1.6f});
run_provider_specific_optest(test_record);
}
void wait_event(int64_t event_id) {
OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain);
test_wait.AddInput<int64_t>("EventIdentifier", {}, {event_id});
test_wait.AddInput<bool>("InputSignal", {}, {true});
test_wait.AddOutput<bool>("OutputSignal", {}, {true});
run_provider_specific_optest(test_wait);
}
void wait_event_multiple_inputs_and_outputs(int64_t event_id) {
OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain);
test_wait.AddInput<int64_t>("EventIdentifier", {}, {event_id});
test_wait.AddInput<bool>("InputSignal", {}, {true});
test_wait.AddInput<float>("Input1", {1}, {1.6f});
test_wait.AddInput<float>("Input2", {3}, {9.4f, 1.7f, 3.6f});
test_wait.AddOutput<bool>("OutputSignal", {}, {true});
test_wait.AddOutput<float>("output1", {1}, {1.6f});
test_wait.AddOutput<float>("output2", {3}, {9.4f, 1.7f, 3.6f});
run_provider_specific_optest(test_wait);
}
TEST(Synchronization, RecordAndWaitEvent) {
const int64_t event_id = static_cast<int64_t>(1736);
record_event(event_id);
wait_event(event_id);
}
TEST(Synchronization, WaitNullEvent) {
wait_event(-1);
}
TEST(Synchronization, RecordAndWaitEventMultipleInputsAndOutputs) {
const int64_t event_id = static_cast<int64_t>(995);
record_event_multiple_inputs_and_outputs(event_id);
wait_event_multiple_inputs_and_outputs(event_id);
}
TEST(Synchronization, WaitAndRecordEvent) {
const int64_t event_id = static_cast<int64_t>(1228);
std::thread waiting_thread(wait_event, event_id);
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::thread recording_thread(record_event, event_id);
waiting_thread.join();
recording_thread.join();
}
TEST(Synchronization, WaitAndRecordEventMany) {
const size_t event_count = 16;
for (int i = 0; i < 8; ++i) {
std::thread thread_pool[2 * event_count];
for (int j = 0; j < static_cast<int>(event_count); ++j) {
thread_pool[j] = std::thread(wait_event, j);
thread_pool[j + event_count] = std::thread(record_event, j);
}
for (size_t j = 0; j < event_count; ++j) {
thread_pool[j].join();
thread_pool[j + event_count].join();
}
}
}
} // namespace test
} // namespace onnxruntime

View file

@ -1798,53 +1798,6 @@ TEST(GradientCheckerTest, SliceGrad) {
}
}
void record_event(int64_t event_id) {
OpTester test_record("RecordEvent", 1, onnxruntime::kMSDomain);
test_record.AddInput<int64_t>("EventIdentifier", {}, {event_id});
test_record.AddInput<bool>("InputSignal", {}, {true});
test_record.AddOutput<bool>("OutputSignal", {}, {true});
test_record.Run();
}
void wait_event(int64_t event_id) {
OpTester test_wait("WaitEvent", 1, onnxruntime::kMSDomain);
test_wait.AddInput<int64_t>("EventIdentifier", {}, {event_id});
test_wait.AddInput<bool>("InputSignal", {}, {true});
test_wait.AddOutput<bool>("OutputSignal", {}, {true});
test_wait.Run();
}
TEST(Synchronization, RecordAndWaitEvent) {
const int64_t event_id = static_cast<int64_t>(1736);
record_event(event_id);
wait_event(event_id);
}
TEST(Synchronization, WaitAndRecordEvent) {
const int64_t event_id = static_cast<int64_t>(1228);
std::thread waiting_thread(wait_event, event_id);
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::thread recording_thread(record_event, event_id);
waiting_thread.join();
recording_thread.join();
}
TEST(Synchronization, WaitAndRecordEventMany) {
const size_t event_count = 16;
for (int i = 0; i < 8; ++i) {
std::thread thread_pool[2 * event_count];
for (int j = 0; j < static_cast<int>(event_count); ++j) {
thread_pool[j] = std::thread(wait_event, j);
thread_pool[j + event_count] = std::thread(record_event, j);
}
for (size_t j = 0; j < event_count; ++j) {
thread_pool[j].join();
thread_pool[j + event_count].join();
}
}
}
TEST(GradientCheckerTest, ExpandGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
template <int input_start, int output_start>
std::vector<std::pair<int, int>> AliasRange(int start, int end) {
std::vector<std::pair<int, int>> aliases;
for (int i = start; i < end; i++) {
aliases.push_back(std::pair<int, int>(input_start + i, output_start + i));
}
return aliases;
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -6,8 +6,16 @@
namespace onnxruntime {
namespace contrib {
void OrtEventPool::CheckRange(const int64_t id) const {
ORT_ENFORCE(
id >= 0 && id < MaxNumItems,
"Got id ", id,
". It should be in a range from 0 to ",
MaxNumItems, ".");
}
void OrtEventPool::SignalEvent(int64_t id) {
ORT_ENFORCE(id >= 0 && id < MaxNumItems);
CheckRange(id);
std::unique_lock<std::mutex> lock(pool_[id].mutex);
pool_[id].signaled.store(true);
lock.unlock();
@ -15,18 +23,18 @@ void OrtEventPool::SignalEvent(int64_t id) {
};
void OrtEventPool::ResetEvent(int64_t id) {
ORT_ENFORCE(id >= 0 && id < MaxNumItems);
CheckRange(id);
std::lock_guard<std::mutex> guard(pool_[id].mutex);
pool_[id].signaled.store(false);
};
bool OrtEventPool::QueryEvent(int64_t id) const {
ORT_ENFORCE(id >= 0 && id < MaxNumItems);
CheckRange(id);
return pool_[id].signaled.load();
}
void OrtEventPool::WaitEvent(int64_t id) const {
ORT_ENFORCE(id >= 0 && id < MaxNumItems);
CheckRange(id);
std::unique_lock<std::mutex> lock(pool_[id].mutex);
pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); });
};

View file

@ -34,6 +34,8 @@ class OrtEventPool final {
OrtEventPool(const OrtEventPool&) = delete;
OrtEventPool& operator=(const OrtEventPool&) = delete;
void CheckRange(const int64_t event_id) const;
struct Item {
std::atomic<bool> signaled;
mutable std::mutex mutex;
@ -43,9 +45,11 @@ class OrtEventPool final {
signaled.store(false);
}
};
enum {
MaxNumItems = 4096
};
Item pool_[MaxNumItems];
};

View file

@ -1,19 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "record.h"
#include "orttraining/training_ops/cpu/controlflow/record.h"
#include "core/providers/cpu/tensor/utils.h"
#include "common.h"
namespace onnxruntime {
namespace contrib {
template <int input_start, int output_start>
std::vector<std::pair<int, int>> AliasRange(int start, int end) {
std::vector<std::pair<int, int>> aliases;
for (int i = start; i < end; i++) {
aliases.push_back(std::pair<int, int>(input_start + i, output_start + i));
}
return aliases;
void record_event_in_tensor(const Tensor& event_id_tensor) {
const int64_t event_id = *event_id_tensor.template Data<int64_t>();
ORT_ENFORCE(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent");
OrtEventPool::GetInstance().SignalEvent(event_id);
}
ONNX_OPERATOR_KERNEL_EX(
@ -28,12 +26,8 @@ ONNX_OPERATOR_KERNEL_EX(
RecordEvent);
Status RecordEvent::Compute(OpKernelContext* ctx) const {
const Tensor* event_id_tensor = ctx->Input<Tensor>(0);
const int64_t event_id = *event_id_tensor->template Data<int64_t>();
ORT_RETURN_IF_NOT(event_id != -1, "-1 is reserved for skip wait, so cannot be used in RecordEvent");
OrtEventPool::GetInstance().SignalEvent(event_id);
// Pass event-id tensor to event-recording helper function.
record_event_in_tensor(*ctx->Input<Tensor>(0));
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
const Tensor* X = ctx->Input<Tensor>(i_out + 1);

View file

@ -9,6 +9,9 @@
namespace onnxruntime {
namespace contrib {
// Record the event ID stored in the input tensor.
void record_event_in_tensor(const Tensor& event_id_tensor);
class RecordEvent final : public OpKernel {
public:
RecordEvent(const OpKernelInfo& info) : OpKernel(info) { }

View file

@ -1,19 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "wait.h"
#include "orttraining/training_ops/cpu/controlflow/wait.h"
#include "core/providers/cpu/tensor/utils.h"
#include "common.h"
namespace onnxruntime {
namespace contrib {
template <int input_start, int output_start>
std::vector<std::pair<int, int>> AliasRange(int start, int end) {
std::vector<std::pair<int, int>> aliases;
for (int i = start; i < end; i++) {
aliases.push_back(std::pair<int, int>(input_start + i, output_start + i));
void wait_event_in_tensor(const Tensor& event_id_tensor) {
const int64_t event_id = *event_id_tensor.template Data<int64_t>();
// -1 is reserved to skip wait event
if (event_id != -1) {
// Wait the event to be recorded by a RecordEvent operator.
OrtEventPool::GetInstance().WaitEvent(event_id);
// BUGBUG: seems this would cause hang when a event is being waited more than once
// Destory the recorded event.
OrtEventPool::GetInstance().ResetEvent(event_id);
}
return aliases;
}
ONNX_OPERATOR_KERNEL_EX(
@ -28,18 +32,7 @@ ONNX_OPERATOR_KERNEL_EX(
WaitEvent);
Status WaitEvent::Compute(OpKernelContext* ctx) const {
const Tensor* event_id_tensor = ctx->Input<Tensor>(0);
const int64_t event_id = *event_id_tensor->template Data<int64_t>();
// -1 is reserved to skip wait event
if (event_id != -1) {
// Wait the event to be recorded by a RecordEvent operator.
OrtEventPool::GetInstance().WaitEvent(event_id);
// BUGBUG: seems this would cause hang when a event is being waited more than once
// Destory the recorded event.
OrtEventPool::GetInstance().ResetEvent(event_id);
}
wait_event_in_tensor(*ctx->Input<Tensor>(0));
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
const Tensor* X = ctx->Input<Tensor>(i_out + 1);

View file

@ -2,8 +2,6 @@
// Licensed under the MIT License.
#pragma once
#include <thread>
#include <chrono>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "event_pool.h"
@ -11,6 +9,9 @@
namespace onnxruntime {
namespace contrib {
// Wait for the event ID stored in the input tensor.
void wait_event_in_tensor(const Tensor& event_id_tensor);
class WaitEvent final : public OpKernel {
public:
WaitEvent(const OpKernelInfo& info) : OpKernel(info) { }

View file

@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/controlflow/record.h"
#include "core/providers/cpu/tensor/utils.h"
// Include RecordEvent's utility functions shared by CPU and GPU implementations.
#include "orttraining/training_ops/cpu/controlflow/common.h"
// Include event mechanism shared by CPU and GPU implementations.
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
#include "orttraining/training_ops/cpu/controlflow/record.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
RecordEvent,
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) /* Keep EventIdentifier in CPU */
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)),
RecordEvent);
Status RecordEvent::ComputeInternal(OpKernelContext* ctx) const {
// Reuse CPU helper to record event because event tensor is a CPU tensor.
onnxruntime::contrib::record_event_in_tensor(*ctx->Input<Tensor>(0));
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
// This iteration copies (i-1)-th input to i-th output.
const Tensor* X = ctx->Input<Tensor>(i_out + 1);
const TensorShape& data_shape = X->Shape();
Tensor* Y = ctx->Output(i_out, data_shape);
CopyTensor(*X, *Y);
}
return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {
class RecordEvent final : public CudaKernel {
public:
RecordEvent(const OpKernelInfo& info) : CudaKernel(info) { }
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cuda/controlflow/wait.h"
#include "core/providers/cpu/tensor/utils.h"
// Include RecordEvent's utility functions shared by CPU and GPU implementations.
#include "orttraining/training_ops/cpu/controlflow/common.h"
// Include event mechanism shared by CPU and GPU implementations.
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
#include "orttraining/training_ops/cpu/controlflow/wait.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
WaitEvent,
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.InputMemoryType<OrtMemTypeCPUInput>(0) /* CPU variable */
.TypeConstraint("TInt64", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.Alias(onnxruntime::contrib::AliasRange<1, 0>(0, 1024)),
WaitEvent);
Status WaitEvent::ComputeInternal(OpKernelContext* ctx) const {
// Reuse CPU helper to wait event because event tensor is a CPU tensor.
onnxruntime::contrib::wait_event_in_tensor(*ctx->Input<Tensor>(0));
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
// This iteration copies (i-1)-th input to i-th output.
const Tensor* X = ctx->Input<Tensor>(i_out + 1);
const TensorShape& data_shape = X->Shape();
Tensor* Y = ctx->Output(i_out, data_shape);
CopyTensor(*X, *Y);
}
return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
namespace onnxruntime {
namespace cuda {
class WaitEvent final : public CudaKernel {
public:
WaitEvent(const OpKernelInfo& info) : CudaKernel(info) { }
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime

View file

@ -131,6 +131,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent);
#ifdef USE_NCCL
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather);
@ -263,6 +266,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Recv)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent)>,
#ifdef USE_NCCL
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllGather)>,