Make TArray safer to use and update method name for consistency. (#4483)

- make size_ and data_ data members private
- rename GetCapacity() to Capacity() to be consistent (e.g., with Size())
- add static_assert for trivially copyable T because it is copied with memcpy
This commit is contained in:
edgchen1 2020-07-13 09:59:56 -07:00 committed by GitHub
parent 00706e1502
commit c71c49aaa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 174 additions and 172 deletions

View file

@ -53,18 +53,18 @@ __global__ void _ElementWiseWithStrideTwo(
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}
int q, r;
fdm_output_strides.data_[dim].divmod(offset, q, r);
fdm_output_strides[dim].divmod(offset, q, r);
if (lhs_need_compute) {
lhs_index += static_cast<int>(lhs_padded_strides.data_[dim]) * q;
lhs_index += static_cast<int>(lhs_padded_strides[dim]) * q;
}
if (rhs_need_compute) {
rhs_index += static_cast<int>(rhs_padded_strides.data_[dim]) * q;
rhs_index += static_cast<int>(rhs_padded_strides[dim]) * q;
}
offset = r;
}
@ -109,7 +109,7 @@ void ComplexMul_Impl(
int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
if (lhs_padded_strides && rhs_padded_strides && lhs_padded_strides->size_ && rhs_padded_strides->size_)
if (lhs_padded_strides && rhs_padded_strides && lhs_padded_strides->Size() && rhs_padded_strides->Size())
_ElementWiseWithStrideTwo<T, true, true, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
*lhs_padded_strides,
@ -122,7 +122,7 @@ void ComplexMul_Impl(
lhs_size,
rhs_size,
is_conj);
else if (lhs_padded_strides && lhs_padded_strides->size_)
else if (lhs_padded_strides && lhs_padded_strides->Size())
_ElementWiseWithStrideTwo<T, true, false, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank_or_simple_broadcast,
*lhs_padded_strides,

View file

@ -34,7 +34,7 @@ __global__ void _BinaryElementWise(
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}

View file

@ -57,6 +57,7 @@ struct TArray {
}
TArray(const std::vector<T>& vec) : TArray(static_cast<int32_t>(vec.size())) {
static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable.");
memcpy(data_, vec.data(), vec.size() * sizeof(T));
}
@ -87,9 +88,9 @@ struct TArray {
return data_;
}
static constexpr int32_t GetCapacity() { return capacity; };
static constexpr int32_t Capacity() { return capacity; };
public: // TODO make these private
private:
int32_t size_;
T data_[capacity];
};

View file

@ -63,7 +63,7 @@ __global__ void ExpandKernel(
CUDA_LONG index = 0;
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < output_strides.GetCapacity(); dim++) {
for (auto dim = 0; dim < output_strides.Capacity(); dim++) {
if (dim >= rank) {
break;
}
@ -143,7 +143,7 @@ Status ExpandImpl(
void* output_data,
const TArray<fast_divmod>& output_strides,
const TArray<int64_t>& input_strides) {
const int rank = static_cast<int>(output_strides.size_);
const int rank = static_cast<int>(output_strides.Size());
if (rank == 1) {
if (N_input == N_output) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, N_output * element_size, cudaMemcpyDeviceToDevice));
@ -159,12 +159,12 @@ Status ExpandImpl(
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(N_output, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
#define EXPAND_ON(TYPE) \
case sizeof(TYPE): \
ExpandKernel<TYPE, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>( \
rank, N_output, reinterpret_cast<const TYPE*>(input_data), reinterpret_cast<TYPE*>(output_data), \
output_strides, input_strides); \
#define EXPAND_ON(TYPE) \
case sizeof(TYPE): \
ExpandKernel<TYPE, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>( \
rank, N_output, reinterpret_cast<const TYPE*>(input_data), reinterpret_cast<TYPE*>(output_data), \
output_strides, input_strides); \
break
switch (element_size) {

View file

@ -467,7 +467,7 @@ void ResizeNearestImpl(
bool could2d = rank >= 2 &&
transform_coordinate != GetDeviceOriginalCoordinateFunc(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE) &&
std::all_of(scales_vals.data_, scales_vals.data_ + (rank - 2), [](float v) { return v == 1.0; });
std::all_of(scales_vals.Data(), scales_vals.Data() + (rank - 2), [](float v) { return v == 1.0; });
if (could2d) {
int64_t output_height = output_shape[rank - 2];
int64_t output_width = output_shape[rank - 1];
@ -502,7 +502,7 @@ void ResizeNearestImpl(
return;
}
int64_t total_dim_sum = std::accumulate(output_shape.data_, output_shape.data_ + rank, 0);
int64_t total_dim_sum = std::accumulate(output_shape.Data(), output_shape.Data() + rank, 0);
int blocksPerDimsMappingGrid = (int)(ceil(static_cast<double>(total_dim_sum) / 32));
_ResizeNearestMappingKernel<T><<<blocksPerDimsMappingGrid, 32, 0>>>(
rank, input_shape, output_shape,
@ -540,7 +540,7 @@ void ResizeImpl(
ResizeCoordinateTransformationMode coordinate_transform_mode,
ResizeNearestMode nearest_mode,
void* dims_mapping) {
bool isSame = std::all_of(scales_vals.data_, scales_vals.data_ + rank, [](float v) { return v == 1.0f; }) &&
bool isSame = std::all_of(scales_vals.Data(), scales_vals.Data() + rank, [](float v) { return v == 1.0f; }) &&
(coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE);
if (isSame) {
cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice);
@ -606,12 +606,12 @@ void ResizeImpl(
template void ResizeImpl<T>( \
const UpsampleMode upsample_mode, \
const int rank, \
TArray<int64_t>& input_shape, \
TArray<int64_t>& output_shape, \
TArray<int64_t>& input_strides, \
TArray<fast_divmod>& output_div_pitches, \
TArray<float>& scales_vals, \
TArray<float>& roi_vals, \
TArray<int64_t>& input_shape, \
TArray<int64_t>& output_shape, \
TArray<int64_t>& input_strides, \
TArray<fast_divmod>& output_div_pitches, \
TArray<float>& scales_vals, \
TArray<float>& roi_vals, \
const T* input_data, \
T* output_data, \
const size_t N, \

View file

@ -192,7 +192,7 @@ Status ScatterElementsImplInternal(
std::vector<int64_t> eff_input_dims;
std::vector<int64_t> eff_indices_dims;
int new_axis = CompactInputIndicesDims(
rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims);
rank, axis, buffer_input_dims.Data(), buffer_indices_dims.Data(), eff_input_dims, eff_indices_dims);
if (eff_input_dims.size() == 2) {
return ScatterElementsImpl2D(
input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data,

View file

@ -101,7 +101,7 @@ Status Slice<dynamic>::ComputeInternal(OpKernelContext* ctx) const {
TArray<int64_t> starts_buffer(starts);
TArray<int64_t> steps_buffer(steps);
TArray<int64_t> input_strides(gsl::narrow_cast<int32_t>(dimension_count));
const gsl::span<int64_t> input_strides_span = gsl::make_span(input_strides.data_, input_strides.size_);
const gsl::span<int64_t> input_strides_span = gsl::make_span(input_strides.Data(), input_strides.Size());
if (p_flattened_output_dims != nullptr) {
// we were able to flatten the innermost dimensions as they're being copied in full to the output.
// do the same flattening to the innermost input dimensions in order to calculate pitches that match

View file

@ -24,16 +24,16 @@ __global__ void _SliceKernel(const int32_t dimension_count,
int value = id;
int dim = 0;
#pragma unroll
for (; dim < starts.GetCapacity(); ++dim) {
for (; dim < starts.Capacity(); ++dim) {
if (dim >= dimension_count - 1) {
break;
}
output_strides.data_[dim].divmod(value, div, mod);
input_index += (starts.data_[dim] + div * steps.data_[dim]) * input_strides.data_[dim];
output_strides[dim].divmod(value, div, mod);
input_index += (starts[dim] + div * steps[dim]) * input_strides[dim];
value = mod;
}
input_index += starts.data_[dim] + mod * steps.data_[dim];
input_index += starts[dim] + mod * steps[dim];
if (is_grad)
output_data[input_index] = input_data[id];
else

View file

@ -13,7 +13,7 @@ template <typename T>
__global__ void Transpose3DKernel(const TArray<int64_t> input_shape,
const TArray<int64_t> input_strides,
const T* input_data, T* output_data) {
__shared__ T tile[TILE_DIM * (TILE_DIM+1)];
__shared__ T tile[TILE_DIM * (TILE_DIM + 1)];
int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
@ -32,9 +32,9 @@ bool CanDoTranspose3D(int32_t rank,
const std::vector<size_t>& permutations) {
if (rank == 3 &&
// permutation is done in the last two dimensions.
permutations[rank-2] == (rank-1) && permutations[rank-1] == (rank-2) &&
permutations[rank - 2] == (rank - 1) && permutations[rank - 1] == (rank - 2) &&
// the last two dimensions are aligned with TILE_DIM.
input_dims[rank-2] % TILE_DIM == 0 && input_dims[rank-1] % TILE_DIM == 0) {
input_dims[rank - 2] % TILE_DIM == 0 && input_dims[rank - 1] % TILE_DIM == 0) {
return true;
}
return false;
@ -44,7 +44,7 @@ Status Transpose3DImpl(size_t element_size,
const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides,
const void* input_data, void* output_data, int64_t N) {
dim3 block_size(TILE_DIM, TILE_DIM);
dim3 grid_size(input_shape[2]/TILE_DIM, input_shape[1]/TILE_DIM, input_shape[0]);
dim3 grid_size(input_shape[2] / TILE_DIM, input_shape[1] / TILE_DIM, input_shape[0]);
switch (element_size) {
case sizeof(int8_t):
@ -73,7 +73,7 @@ Status Transpose3DImpl(size_t element_size,
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
element_size);
element_size);
}
return Status::OK();
@ -86,13 +86,15 @@ __global__ void Transpose4DKernel(const TArray<int64_t> input_strides, const voi
// output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x
CUDA_LONG input_index = (blockIdx.y * input_strides[0] +
blockIdx.x * input_strides[1] +
threadIdx.y * input_strides[2]) / (4 * sizeof(int) / element_size) +
threadIdx.x * input_strides[3];
threadIdx.y * input_strides[2]) /
(4 * sizeof(int) / element_size) +
threadIdx.x * input_strides[3];
CUDA_LONG output_index = (blockIdx.y * output_strides[0] +
blockIdx.x * output_strides[1] +
threadIdx.y * output_strides[2]) / (4 * sizeof(int) / element_size) +
threadIdx.x * output_strides[3];
threadIdx.y * output_strides[2]) /
(4 * sizeof(int) / element_size) +
threadIdx.x * output_strides[3];
const int4* v_input = reinterpret_cast<const int4*>(input_data);
int4* v_output = reinterpret_cast<int4*>(output_data);
@ -109,12 +111,11 @@ bool CanDoTranspose4D(const cudaDeviceProp& prop,
const std::vector<size_t>& permutations) {
if (rank == 4 &&
// the permutations is not on the last dimension.
permutations[rank-1] == (rank - 1)) {
permutations[rank - 1] == (rank - 1)) {
// The block size will be set based on the last two dimensions of 4D tensor.
// the number threads per block will be calculated as below.
int num_elements_per_thread = 4 * sizeof(int) / element_size; // int4 is used in the kernel to access data.
int64_t num_elements_in_last_two_dimensions = input_dims[rank-2] * input_dims[rank-1];
int num_elements_per_thread = 4 * sizeof(int) / element_size; // int4 is used in the kernel to access data.
int64_t num_elements_in_last_two_dimensions = input_dims[rank - 2] * input_dims[rank - 1];
int64_t num_threads_per_block = num_elements_in_last_two_dimensions / num_elements_per_thread;
if (((num_elements_in_last_two_dimensions & (num_elements_per_thread - 1)) == 0) &&
@ -130,34 +131,34 @@ bool CanDoTranspose4D(const cudaDeviceProp& prop,
Status Transpose4DImpl(size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
const TArray<int64_t>& output_strides, void* output_data, int64_t N) {
int num_elements_per_thread = 4 * sizeof(int) / element_size; // int4 is used in the kernel to access data.
dim3 block_size(input_shape[3]/num_elements_per_thread, input_shape[2]);
int num_elements_per_thread = 4 * sizeof(int) / element_size; // int4 is used in the kernel to access data.
dim3 block_size(input_shape[3] / num_elements_per_thread, input_shape[2]);
dim3 grid_size(input_shape[1], input_shape[0]);
switch (element_size) {
case sizeof(int8_t):
Transpose4DKernel<sizeof(int8_t)><<<grid_size, block_size, 0>>>(
input_strides, input_data,
output_strides, output_data, N/num_elements_per_thread);
output_strides, output_data, N / num_elements_per_thread);
break;
case sizeof(int16_t):
Transpose4DKernel<sizeof(int16_t)><<<grid_size, block_size, 0>>>(
input_strides, input_data,
output_strides, output_data, N/num_elements_per_thread);
output_strides, output_data, N / num_elements_per_thread);
break;
case sizeof(int32_t):
Transpose4DKernel<sizeof(int32_t)><<<grid_size, block_size, 0>>>(
input_strides, input_data,
output_strides, output_data, N/num_elements_per_thread);
output_strides, output_data, N / num_elements_per_thread);
break;
case sizeof(int64_t):
Transpose4DKernel<sizeof(int64_t)><<<grid_size, block_size, 0>>>(
input_strides, input_data,
output_strides, output_data, N/num_elements_per_thread);
output_strides, output_data, N / num_elements_per_thread);
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
element_size);
element_size);
}
return Status::OK();
@ -170,8 +171,8 @@ __global__ void TransposeKernel(int32_t shape_rank, const TArray<int64_t> input_
CUDA_LONG input_index = 0;
CUDA_LONG output_index = id;
#pragma unroll
for (auto dim = 0; dim < input_strides.GetCapacity(); ++dim) {
#pragma unroll
for (auto dim = 0; dim < input_strides.Capacity(); ++dim) {
if (dim >= shape_rank) {
break;
}

View file

@ -68,10 +68,10 @@ struct TernaryElementwisePreparation {
const Tensor* a_tensor = nullptr;
const Tensor* b_tensor = nullptr;
const Tensor* c_tensor = nullptr;
size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums
TArray<int64_t> a_padded_strides; // for a shape == output shape, this is nullptr
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums
TArray<int64_t> a_padded_strides; // for a shape == output shape, this is nullptr
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
TArray<fast_divmod> fdm_output_strides;
BroadcastIndexType a_index_type = BroadcastIndexType::NoBroadcast;
BroadcastIndexType b_index_type = BroadcastIndexType::NoBroadcast;
@ -98,7 +98,7 @@ struct TernaryElementwisePreparation {
output_rank_or_simple_broadcast = out_rank;
auto padder = [out_rank](int32_t rank, const TensorShape& shape, TArray<int64_t>& padded_strides) {
padded_strides.size_ = out_rank;
padded_strides.SetSize(out_rank);
if (rank > 0) {
TensorPitches pitches(shape.GetDims());
auto offset = out_rank - rank;
@ -142,7 +142,7 @@ struct TernaryElementwisePreparation {
}
TensorPitches output_pitches(output_shape.GetDims());
fdm_output_strides.size_ = out_rank;
fdm_output_strides.SetSize(out_rank);
for (auto i = 0; i < out_rank; ++i) {
fdm_output_strides[i] = fast_divmod(static_cast<int32_t>(output_pitches[i]));
}

View file

@ -37,7 +37,7 @@ __global__ void _TenaryElementWise(
CUDA_LONG y_index = (YIndexType == BroadcastIndexType::NoBroadcast ? id : 0);
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}
@ -111,73 +111,73 @@ __global__ void _TenaryElementWiseSimple(
}
}
#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \
case Y_INDEX_TYPE: { \
_TenaryElementWiseSimple<T, \
COND_INDEX_TYPE, \
X_INDEX_TYPE, \
Y_INDEX_TYPE, \
GridDim::maxThreadsPerBlock, \
GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(cond_data, \
x_data, \
y_data, \
output_data, \
N); \
#define HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \
case Y_INDEX_TYPE: { \
_TenaryElementWiseSimple<T, \
COND_INDEX_TYPE, \
X_INDEX_TYPE, \
Y_INDEX_TYPE, \
GridDim::maxThreadsPerBlock, \
GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(cond_data, \
x_data, \
y_data, \
output_data, \
N); \
} break
#define HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \
case X_INDEX_TYPE: { \
switch(Y_INDEX_TYPE_VAL) { \
HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \
HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \
} \
#define HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \
case X_INDEX_TYPE: { \
switch (Y_INDEX_TYPE_VAL) { \
HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \
HANDLE_Y_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \
} \
} break
#define HANDLE_COND_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \
case COND_INDEX_TYPE: { \
switch(X_INDEX_TYPE_VAL) { \
HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \
HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \
} \
#define HANDLE_COND_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \
case COND_INDEX_TYPE: { \
switch (X_INDEX_TYPE_VAL) { \
HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \
HANDLE_X_INDEX_TYPE_SIMPLE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \
} \
} break
#define HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \
case Y_INDEX_TYPE: { \
_TenaryElementWise<T, \
COND_INDEX_TYPE, \
X_INDEX_TYPE, \
Y_INDEX_TYPE, \
GridDim::maxThreadsPerBlock, \
GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(output_rank_or_simple_broadcast, \
cond_padded_strides, \
cond_data, \
x_padded_strides, \
x_data, \
y_padded_strides, \
y_data, \
fdm_output_strides, \
output_data, \
N); \
#define HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE) \
case Y_INDEX_TYPE: { \
_TenaryElementWise<T, \
COND_INDEX_TYPE, \
X_INDEX_TYPE, \
Y_INDEX_TYPE, \
GridDim::maxThreadsPerBlock, \
GridDim::maxElementsPerThread> \
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(output_rank_or_simple_broadcast, \
cond_padded_strides, \
cond_data, \
x_padded_strides, \
x_data, \
y_padded_strides, \
y_data, \
fdm_output_strides, \
output_data, \
N); \
} break
#define HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \
case X_INDEX_TYPE: { \
switch(Y_INDEX_TYPE_VAL) { \
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NeedCompute); \
} \
#define HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, Y_INDEX_TYPE_VAL) \
case X_INDEX_TYPE: { \
switch (Y_INDEX_TYPE_VAL) { \
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NoBroadcast); \
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::Scalar); \
HANDLE_Y_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE, BroadcastIndexType::NeedCompute); \
} \
} break
#define HANDLE_COND_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \
case COND_INDEX_TYPE: { \
switch(X_INDEX_TYPE_VAL) { \
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NeedCompute, Y_INDEX_TYPE_VAL); \
} \
#define HANDLE_COND_INDEX_TYPE(COND_INDEX_TYPE, X_INDEX_TYPE_VAL, Y_INDEX_TYPE_VAL) \
case COND_INDEX_TYPE: { \
switch (X_INDEX_TYPE_VAL) { \
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NoBroadcast, Y_INDEX_TYPE_VAL); \
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::Scalar, Y_INDEX_TYPE_VAL); \
HANDLE_X_INDEX_TYPE(COND_INDEX_TYPE, BroadcastIndexType::NeedCompute, Y_INDEX_TYPE_VAL); \
} \
} break
template <typename T>
@ -198,12 +198,12 @@ void WhereImpl(
int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
if (output_rank_or_simple_broadcast == static_cast<size_t>(SimpleBroadcast::NoBroadcast)) {
switch(cond_index_type) {
switch (cond_index_type) {
HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type);
HANDLE_COND_INDEX_TYPE_SIMPLE(BroadcastIndexType::Scalar, x_index_type, y_index_type);
}
} else {
switch(cond_index_type) {
switch (cond_index_type) {
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NoBroadcast, x_index_type, y_index_type);
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::Scalar, x_index_type, y_index_type);
HANDLE_COND_INDEX_TYPE(BroadcastIndexType::NeedCompute, x_index_type, y_index_type);

View file

@ -177,19 +177,19 @@ __global__ void _DivGrad(
CUDA_LONG a_index = (a_need_compute ? 0 : id);
CUDA_LONG b_index = (b_need_compute ? 0 : id);
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}
int q, r;
fdm_output_strides.data_[dim].divmod(offset, q, r);
fdm_output_strides[dim].divmod(offset, q, r);
if (a_need_compute) {
a_index += static_cast<int>(a_padded_strides.data_[dim]) * q;
a_index += static_cast<int>(a_padded_strides[dim]) * q;
}
if (b_need_compute) {
b_index += static_cast<int>(b_padded_strides.data_[dim]) * q;
b_index += static_cast<int>(b_padded_strides[dim]) * q;
}
offset = r;
}
@ -209,15 +209,15 @@ __global__ void _DivGrad_A(
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG b_index = (b_need_compute ? 0 : id);
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}
int q, r;
fdm_output_strides.data_[dim].divmod(offset, q, r);
fdm_output_strides[dim].divmod(offset, q, r);
if (b_need_compute) {
b_index += static_cast<int>(b_padded_strides.data_[dim]) * q;
b_index += static_cast<int>(b_padded_strides[dim]) * q;
}
offset = r;
}
@ -239,19 +239,19 @@ __global__ void _DivGrad_B(
CUDA_LONG a_index = (a_need_compute ? 0 : id);
CUDA_LONG b_index = (b_need_compute ? 0 : id);
CUDA_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}
int q, r;
fdm_output_strides.data_[dim].divmod(offset, q, r);
fdm_output_strides[dim].divmod(offset, q, r);
if (a_need_compute) {
a_index += static_cast<int>(a_padded_strides.data_[dim]) * q;
a_index += static_cast<int>(a_padded_strides[dim]) * q;
}
if (b_need_compute) {
b_index += static_cast<int>(b_padded_strides.data_[dim]) * q;
b_index += static_cast<int>(b_padded_strides[dim]) * q;
}
offset = r;
}
@ -441,7 +441,7 @@ void ImplDivGrad(
T* db_output_data) {
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
CUDA_LONG N = static_cast<CUDA_LONG>(count);
if (a_padded_strides && a_padded_strides->size_ && b_padded_strides && b_padded_strides->size_) {
if (a_padded_strides && a_padded_strides->Size() && b_padded_strides && b_padded_strides->Size()) {
if (da_output_data && db_output_data)
_DivGrad<T, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank,
@ -474,7 +474,7 @@ void ImplDivGrad(
*fdm_output_strides,
db_output_data,
N);
} else if (a_padded_strides && a_padded_strides->size_) {
} else if (a_padded_strides && a_padded_strides->Size()) {
if (da_output_data && db_output_data)
_DivGrad<T, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
output_rank,
@ -543,42 +543,42 @@ void ImplDivGrad(
}
} // namespace cuda
#define SPECIALIZED_DIV_GRAD_IMPL(T) \
template void ImplDivGrad<T>( \
int32_t output_rank, \
const TArray<int64_t>* a_padded_strides, \
const T* a_data, \
const TArray<int64_t>* b_padded_strides, \
const T* b_data, \
const T* dy_data, \
size_t count, \
const TArray<fast_divmod>* fdm_output_strides,\
T* da_output_data, \
T* db_output_data); \
template void ImplDivGradRhsPerChannelBatch1<T>( \
const T* a_data, \
const T* b_data, \
const T* dy_data, \
size_t count, \
const fast_divmod& fdm_H, \
T* da_output_data, \
T* db_output_data); \
template void ImplDivGradRhsPerChannelBatchN<T>( \
const T* a_data, \
const T* b_data, \
const T* dy_data, \
size_t count, \
const fast_divmod& fdm_H, \
const fast_divmod& fdm_C, \
T* da_output_data, \
T* db_output_data); \
template void ImplDivGradSimple<T>( \
SimpleBroadcast simpleBroadcast, \
const T* a_data, \
const T* b_data, \
const T* dy_data, \
size_t count, \
T* da_output_data, \
#define SPECIALIZED_DIV_GRAD_IMPL(T) \
template void ImplDivGrad<T>( \
int32_t output_rank, \
const TArray<int64_t>* a_padded_strides, \
const T* a_data, \
const TArray<int64_t>* b_padded_strides, \
const T* b_data, \
const T* dy_data, \
size_t count, \
const TArray<fast_divmod>* fdm_output_strides, \
T* da_output_data, \
T* db_output_data); \
template void ImplDivGradRhsPerChannelBatch1<T>( \
const T* a_data, \
const T* b_data, \
const T* dy_data, \
size_t count, \
const fast_divmod& fdm_H, \
T* da_output_data, \
T* db_output_data); \
template void ImplDivGradRhsPerChannelBatchN<T>( \
const T* a_data, \
const T* b_data, \
const T* dy_data, \
size_t count, \
const fast_divmod& fdm_H, \
const fast_divmod& fdm_C, \
T* da_output_data, \
T* db_output_data); \
template void ImplDivGradSimple<T>( \
SimpleBroadcast simpleBroadcast, \
const T* a_data, \
const T* b_data, \
const T* dy_data, \
size_t count, \
T* da_output_data, \
T* db_output_data);
SPECIALIZED_DIV_GRAD_IMPL(half)