Optimize Where CUDA kernel for UniLMV2. (#3799)

Co-authored-by: Vincent Wang <weicwang@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Vincent Wang 2020-05-07 10:23:54 +08:00 committed by GitHub
parent 65bfece19d
commit c222ed6327
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 208 additions and 151 deletions

View file

@ -21,6 +21,12 @@ enum class SimpleBroadcast : int32_t {
RightPerChannelBatchN = (int32_t)-5,
};
// Describes how an element-wise kernel turns an output offset into an offset
// into one of its input operands.
enum class BroadcastIndexType : int32_t {
  // The operand is read at the output offset itself.
  NoBroadcast = 0,
  // The operand holds a single element, which is read at offset 0 every time.
  Scalar = 1,
  // The operand offset is rebuilt for each element from its padded strides.
  NeedCompute = 2,
};
template <typename T>
class IConstantBuffer {
public:

View file

@ -73,6 +73,9 @@ struct TernaryElementwisePreparation {
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
TArray<fast_divmod> fdm_output_strides;
BroadcastIndexType a_index_type = BroadcastIndexType::NoBroadcast;
BroadcastIndexType b_index_type = BroadcastIndexType::NoBroadcast;
BroadcastIndexType c_index_type = BroadcastIndexType::NoBroadcast;
TernaryElementwisePreparation(const Tensor* a, const Tensor* b, const Tensor* c)
: a_tensor(a), b_tensor(b), c_tensor(c) {}
@ -108,16 +111,34 @@ struct TernaryElementwisePreparation {
}
};
if (a_shape != output_shape) {
bool has_need_compute = false;
if (a_shape.Size() == 1) {
a_index_type = BroadcastIndexType::Scalar;
} else if (a_shape != output_shape) {
padder(a_rank, a_shape, a_padded_strides);
a_index_type = BroadcastIndexType::NeedCompute;
has_need_compute = true;
}
if (b_shape != output_shape) {
padder(b_rank, b_shape, b_padded_strides);
if (b_shape.Size() == 1) {
b_index_type = BroadcastIndexType::Scalar;
} else if (b_shape != output_shape) {
padder(b_rank, b_shape, b_padded_strides);
b_index_type = BroadcastIndexType::NeedCompute;
has_need_compute = true;
}
if (c_shape != output_shape) {
if (c_shape.Size() == 1) {
c_index_type = BroadcastIndexType::Scalar;
} else if (c_shape != output_shape) {
padder(c_rank, c_shape, c_padded_strides);
c_index_type = BroadcastIndexType::NeedCompute;
has_need_compute = true;
}
if (!has_need_compute) {
output_rank_or_simple_broadcast = static_cast<size_t>(SimpleBroadcast::NoBroadcast);
return Status::OK();
}
TensorPitches output_pitches(output_shape.GetDims());
@ -154,10 +175,13 @@ Status Where<T>::ComputeInternal(OpKernelContext* context) const {
WhereImpl<CudaT>(
prepare.output_rank_or_simple_broadcast,
prepare.a_index_type,
prepare.a_padded_strides,
reinterpret_cast<const bool*>(prepare.a_tensor->template Data<bool>()),
prepare.b_index_type,
prepare.b_padded_strides,
reinterpret_cast<const CudaT*>(prepare.b_tensor->template Data<T>()),
prepare.c_index_type,
prepare.c_padded_strides,
reinterpret_cast<const CudaT*>(prepare.c_tensor->template Data<T>()),
prepare.fdm_output_strides,

View file

@ -10,7 +10,7 @@ namespace onnxruntime {
namespace cuda {
// broadcast by computing output coordinate from offset, using fast_divmod
//
// Writes output_data[id] = cond ? x : y for operands that may need broadcasting.
// For each output dimension, fdm_output_strides[dim].divmod() splits the remaining
// offset into a quotient q and a remainder r; q times the operand's padded stride is
// added to that operand's index when its index type is NeedCompute.
//
// NOTE(review): this span appears to be a rendered diff. Two template headers and two
// bodies (a one-element-per-thread body and a NumElementsPerThread body) are
// interleaved, and the "@ ..." hunk marker hides part of the parameter list.
// Confirm statement order against the real file before relying on it.
template <typename T, bool cond_need_compute, bool x_need_compute, bool y_need_compute>
template <typename T, BroadcastIndexType CondIndexType, BroadcastIndexType XIndexType, BroadcastIndexType YIndexType, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _TenaryElementWise(
size_t output_rank,
const TArray<int64_t> cond_padded_strides,
@ -22,184 +22,208 @@ __global__ void _TenaryElementWise(
const TArray<fast_divmod> fdm_output_strides,
T* output_data,
CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG cond_index = (cond_need_compute ? 0 : id);
CUDA_LONG x_index = (x_need_compute ? 0 : id);
CUDA_LONG y_index = (y_need_compute ? 0 : id);
// First element handled by this thread; its later elements are NumThreadsPerBlock apart.
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
// Per-thread staging for the loaded operands, one slot per handled element.
bool cond_value[NumElementsPerThread];
T x_value[NumElementsPerThread];
T y_value[NumElementsPerThread];
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
CUDA_LONG offset = id;
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
if (dim >= output_rank) {
break;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
// NoBroadcast operands start at the output offset; the other index types start at 0.
CUDA_LONG cond_index = (CondIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
CUDA_LONG x_index = (XIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
CUDA_LONG y_index = (YIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
// Only the first output_rank entries of the stride arrays are used.
if (dim >= output_rank) {
break;
}
int q, r;
fdm_output_strides[dim].divmod(offset, q, r);
if (CondIndexType == BroadcastIndexType::NeedCompute) {
cond_index += static_cast<int>(cond_padded_strides[dim]) * q;
}
if (XIndexType == BroadcastIndexType::NeedCompute) {
x_index += static_cast<int>(x_padded_strides[dim]) * q;
}
if (YIndexType == BroadcastIndexType::NeedCompute) {
y_index += static_cast<int>(y_padded_strides[dim]) * q;
}
// Carry the remainder into the next dimension.
offset = r;
}
// Stage the three operands; the select itself happens in the store loop further down.
cond_value[i] = cond_data[cond_index];
x_value[i] = x_data[x_index];
y_value[i] = y_data[y_index];
id += NumThreadsPerBlock;
}
int q, r;
fdm_output_strides[dim].divmod(offset, q, r);
if (cond_need_compute) {
cond_index += static_cast<int>(cond_padded_strides[dim]) * q;
}
if (x_need_compute) {
x_index += static_cast<int>(x_padded_strides[dim]) * q;
}
if (y_need_compute) {
y_index += static_cast<int>(y_padded_strides[dim]) * q;
}
offset = r;
}
output_data[id] = cond_data[cond_index] ? x_data[x_index] : y_data[y_index];
// Second pass over the same ids: write the selected value for each staged element.
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = cond_value[i] ? x_value[i] : y_value[i];
id += NumThreadsPerBlock;
}
}
}
// Where kernel for the scalar-broadcast / non-broadcast case: no operand needs a
// per-element index computation.
//
// Template parameters:
//   T                     element type of x, y and the output.
//   CondIndexType,
//   XIndexType,
//   YIndexType            NoBroadcast: the operand is read at the output offset;
//                         any other value: the operand is read at offset 0.
//   NumThreadsPerBlock    per-thread step between handled elements; expected to equal
//                         the block size used at launch.
//   NumElementsPerThread  maximum number of output elements handled by one thread.
//
// Parameters:
//   cond_data, x_data, y_data  input operands, dereferenced on the device.
//   output_data                destination; element id receives cond ? x : y.
//   N                          number of output elements; ids >= N are skipped.
//
// Expected launch: a 1-D grid in which each block covers
// NumThreadsPerBlock * NumElementsPerThread consecutive output elements.
template <typename T, BroadcastIndexType CondIndexType, BroadcastIndexType XIndexType, BroadcastIndexType YIndexType, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _TenaryElementWiseSimple(
    const bool* cond_data,
    const T* x_data,
    const T* y_data,
    T* output_data,
    CUDA_LONG N) {
  // First element owned by this thread; the following ones are NumThreadsPerBlock apart.
  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;

  // Per-thread staging so that all loads are issued before any store.
  bool cond_value[NumElementsPerThread];
  T x_value[NumElementsPerThread];
  T y_value[NumElementsPerThread];

  CUDA_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      cond_value[i] = cond_data[CondIndexType == BroadcastIndexType::NoBroadcast ? id : 0];
      x_value[i] = x_data[XIndexType == BroadcastIndexType::NoBroadcast ? id : 0];
      y_value[i] = y_data[YIndexType == BroadcastIndexType::NoBroadcast ? id : 0];
      // Advance only while in range so that id never runs far past N.
      id += NumThreadsPerBlock;
    }
  }

  // Replay the same ids and write the selected values.
  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = cond_value[i] ? x_value[i] : y_value[i];
      id += NumThreadsPerBlock;
    }
  }
}
// Innermost dispatch level for the simple path: expands to a `case Y_KIND` that launches
// _TenaryElementWiseSimple specialised for the three index kinds.
// The expansion site must have T, blocksPerGrid, cond_data, x_data, y_data, output_data
// and N in scope. The expansion ends in `break` without a semicolon; the user adds it.
#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_KIND, X_KIND, Y_KIND) \
  case Y_KIND: { \
    _TenaryElementWiseSimple<T, COND_KIND, X_KIND, Y_KIND, \
                             GridDim::maxThreadsPerBlock, \
                             GridDim::maxElementsPerThread> \
        <<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>( \
            cond_data, x_data, y_data, output_data, N); \
  } break
// Middle dispatch level for the simple path: expands to a `case X_KIND` that switches on
// the runtime Y index type and forwards to HANDLE_Y_INDEX_TYPE_SIMPLE.
// Only NoBroadcast and Scalar are enumerated for Y.
#define HANDLE_X_INDEX_TYPE_SIMPLE(COND_KIND, X_KIND, Y_RUNTIME_KIND) \
  case X_KIND: { \
    switch (Y_RUNTIME_KIND) { \
      HANDLE_Y_INDEX_TYPE_SIMPLE(COND_KIND, X_KIND, BroadcastIndexType::NoBroadcast); \
      HANDLE_Y_INDEX_TYPE_SIMPLE(COND_KIND, X_KIND, BroadcastIndexType::Scalar); \
    } \
  } break
// Outermost dispatch level for the simple path: expands to a `case COND_KIND` that
// switches on the runtime X index type and forwards to HANDLE_X_INDEX_TYPE_SIMPLE.
// Only NoBroadcast and Scalar are enumerated for X.
#define HANDLE_COND_INDEX_TYPE_SIMPLE(COND_KIND, X_RUNTIME_KIND, Y_RUNTIME_KIND) \
  case COND_KIND: { \
    switch (X_RUNTIME_KIND) { \
      HANDLE_X_INDEX_TYPE_SIMPLE(COND_KIND, BroadcastIndexType::NoBroadcast, Y_RUNTIME_KIND); \
      HANDLE_X_INDEX_TYPE_SIMPLE(COND_KIND, BroadcastIndexType::Scalar, Y_RUNTIME_KIND); \
    } \
  } break
// Innermost dispatch level for the broadcasting path: expands to a `case Y_KIND` that
// launches _TenaryElementWise specialised for the three index kinds.
// The expansion site must have T, blocksPerGrid, output_rank_or_simple_broadcast, the three
// *_padded_strides arrays, the three *_data pointers, fdm_output_strides, output_data and N
// in scope. The expansion ends in `break` without a semicolon; the user adds it.
#define HANDLE_Y_INDEX_TYPE(COND_KIND, X_KIND, Y_KIND) \
  case Y_KIND: { \
    _TenaryElementWise<T, COND_KIND, X_KIND, Y_KIND, \
                       GridDim::maxThreadsPerBlock, \
                       GridDim::maxElementsPerThread> \
        <<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>( \
            output_rank_or_simple_broadcast, \
            cond_padded_strides, cond_data, \
            x_padded_strides, x_data, \
            y_padded_strides, y_data, \
            fdm_output_strides, \
            output_data, \
            N); \
  } break
// Middle dispatch level for the broadcasting path: expands to a `case X_KIND` that switches
// on the runtime Y index type and forwards to HANDLE_Y_INDEX_TYPE for all three kinds.
#define HANDLE_X_INDEX_TYPE(COND_KIND, X_KIND, Y_RUNTIME_KIND) \
  case X_KIND: { \
    switch (Y_RUNTIME_KIND) { \
      HANDLE_Y_INDEX_TYPE(COND_KIND, X_KIND, BroadcastIndexType::NoBroadcast); \
      HANDLE_Y_INDEX_TYPE(COND_KIND, X_KIND, BroadcastIndexType::Scalar); \
      HANDLE_Y_INDEX_TYPE(COND_KIND, X_KIND, BroadcastIndexType::NeedCompute); \
    } \
  } break
// Outermost dispatch level for the broadcasting path: expands to a `case COND_KIND` that
// switches on the runtime X index type and forwards to HANDLE_X_INDEX_TYPE for all three kinds.
#define HANDLE_COND_INDEX_TYPE(COND_KIND, X_RUNTIME_KIND, Y_RUNTIME_KIND) \
  case COND_KIND: { \
    switch (X_RUNTIME_KIND) { \
      HANDLE_X_INDEX_TYPE(COND_KIND, BroadcastIndexType::NoBroadcast, Y_RUNTIME_KIND); \
      HANDLE_X_INDEX_TYPE(COND_KIND, BroadcastIndexType::Scalar, Y_RUNTIME_KIND); \
      HANDLE_X_INDEX_TYPE(COND_KIND, BroadcastIndexType::NeedCompute, Y_RUNTIME_KIND); \
    } \
  } break
template <typename T>
void WhereImpl(
size_t output_rank_or_simple_broadcast,
BroadcastIndexType cond_index_type,
const TArray<int64_t>& cond_padded_strides,
const bool* cond_data,
BroadcastIndexType x_index_type,
const TArray<int64_t>& x_padded_strides,
const T* x_data,
BroadcastIndexType y_index_type,
const TArray<int64_t>& y_padded_strides,
const T* y_data,
const TArray<fast_divmod>& fdm_output_strides,
T* output_data,
size_t count) {
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
if (output_rank_or_simple_broadcast == static_cast<size_t>(SimpleBroadcast::NoBroadcast)) {
_TenaryElementWiseSimple<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
cond_data,
x_data,
y_data,
output_data,
N);
switch(cond_index_type) {
HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type);
HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::Scalar, x_index_type, y_index_type);
}
} else {
if (cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) {
_TenaryElementWise<T, true, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else if (cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) {
_TenaryElementWise<T, true, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else if (cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) {
_TenaryElementWise<T, true, false, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else if (!cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) {
_TenaryElementWise<T, false, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else if (cond_padded_strides.size_ && !x_padded_strides.size_ && !y_padded_strides.size_) {
_TenaryElementWise<T, true, false, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else if (!cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) {
_TenaryElementWise<T, false, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else if (!cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) {
_TenaryElementWise<T, false, false, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
} else {
_TenaryElementWise<T, false, false, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
cond_padded_strides,
cond_data,
x_padded_strides,
x_data,
y_padded_strides,
y_data,
fdm_output_strides,
output_data,
N);
}
switch(cond_index_type) {
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type);
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::Scalar, x_index_type, y_index_type);
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NeedCompute, x_index_type, y_index_type);
}
}
}
// Explicitly instantiates WhereImpl<T> for one element type.
// The parameter list must mirror the WhereImpl definition exactly, including the three
// BroadcastIndexType arguments.
#define SPECIALIZED_IMPL(T) \
  template void WhereImpl<T>(size_t output_rank_or_simple_broadcast, \
                             BroadcastIndexType cond_index_type, \
                             const TArray<int64_t>& cond_padded_strides, \
                             const bool* cond_data, \
                             BroadcastIndexType x_index_type, \
                             const TArray<int64_t>& x_padded_strides, \
                             const T* x_data, \
                             BroadcastIndexType y_index_type, \
                             const TArray<int64_t>& y_padded_strides, \
                             const T* y_data, \
                             const TArray<fast_divmod>& fdm_output_strides, \
                             T* output_data, \
                             size_t count);
SPECIALIZED_IMPL(uint8_t)
@ -209,4 +233,4 @@ SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(half)
} // namespace cuda
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -12,10 +12,13 @@ namespace cuda {
template <typename T>
void WhereImpl(
size_t output_rank_or_simple_broadcast,
BroadcastIndexType cond_index_type,
const TArray<int64_t>& cond_padded_strides,
const bool* cond_data,
BroadcastIndexType x_index_type,
const TArray<int64_t>& x_padded_strides,
const T* x_data,
BroadcastIndexType y_index_type,
const TArray<int64_t>& y_padded_strides,
const T* y_data,
const TArray<fast_divmod>& fdm_output_strides,