[CUDA] Split/Concat Kernel Optimization (#12175)

* split concat optimization

* bugfix

* fix ut

* deprecate LooseVersion
This commit is contained in:
Vincent Wang 2022-07-19 08:10:46 +08:00 committed by GitHub
parent ced7c2deac
commit 173bcdbc71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 394 additions and 401 deletions

View file

@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "concat.h"
#include "concat_impl.h"
#include "core/providers/cuda/tensor/concat.h"
#include "core/providers/cuda/tensor/concat_impl.h"
namespace onnxruntime {
namespace cuda {
@ -63,31 +64,41 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const {
axis_dimension_input_output_mapping.at(index++) = i;
}
}
std::vector<int64_t> concat_sizes_range(concat_sizes);
for (size_t i = 1; i < concat_sizes_range.size(); ++i) {
concat_sizes_range[i] += concat_sizes_range[i - 1];
}
CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, concat_sizes);
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, concat_sizes_range);
ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu());
auto element_bytes = p.output_tensor->DataType()->Size();
int block_size_inside_axis_dim = static_cast<int>(p.output_axis_pitch / p.output_tensor->Shape()[p.axis]);
int block_size_including_axis_dim = static_cast<int>(p.output_axis_pitch);
auto element_bytes = p.output_tensor->DataType()->Size();
ORT_RETURN_IF_ERROR(ConcatImpl(Stream(),
element_bytes,
block_size_including_axis_dim,
block_size_inside_axis_dim,
concat_sizes_gpu.GpuPtr(),
concat_sizes_range_gpu.GpuPtr(),
axis_dimension_input_output_mapping_gpu.GpuPtr(),
p.output_tensor->MutableDataRaw(),
input_ptr.GpuPtr(),
p.output_num_elements));
if (std::all_of(concat_sizes.begin(), concat_sizes.end(), [&](int64_t size) { return size == concat_sizes[0]; })) {
if (input_count <= 32) {
TArray<const void*, 32> input_ptr_array(input_count);
for (int i = 0; i < input_count; ++i) input_ptr_array[i] = input_ptr_cpuspan[i];
ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(
Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0],
p.output_tensor->MutableDataRaw(), input_ptr_array, static_cast<size_t>(p.output_num_elements)));
} else {
ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu());
ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(
Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0],
p.output_tensor->MutableDataRaw(), input_ptr.GpuPtr(), static_cast<size_t>(p.output_num_elements)));
}
} else {
CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, concat_sizes);
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
std::vector<int64_t> concat_sizes_range(concat_sizes);
for (size_t i = 1; i < concat_sizes_range.size(); ++i) {
concat_sizes_range[i] += concat_sizes_range[i - 1];
}
CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, concat_sizes_range);
ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu());
ORT_RETURN_IF_ERROR(ConcatImpl(Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim,
concat_sizes_gpu.GpuPtr(), concat_sizes_range_gpu.GpuPtr(),
axis_dimension_input_output_mapping_gpu.GpuPtr(), p.output_tensor->MutableDataRaw(),
input_ptr.GpuPtr(), static_cast<size_t>(p.output_num_elements)));
}
return Status::OK();
}
} // namespace cuda

View file

@ -1,89 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/tensor/concat_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "concat_impl.h"
namespace onnxruntime {
namespace cuda {
namespace {
#ifdef USE_ROCM
constexpr int kNumElementsPerThread = 2;
constexpr int kNumThreadsPerBlock = 512;
#else
constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread;
constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock;
#endif
} // namespace
// concat dimension are same for all inputs
template <typename T, typename InputIndexToMemoryMap>
template <typename T, typename InputDataArray>
__global__ void _ConcatKernelSameConcatDim(const fast_divmod block_size_including_axis_dim_div,
const fast_divmod block_size_inside_axis_dim_div,
const fast_divmod concat_dim_size,
T* output_data,
InputIndexToMemoryMap input_ptr,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_pos = 0;
const fast_divmod block_size_inside_axis_dim_div,
const fast_divmod concat_dim_size, T* output_data, InputDataArray input_data,
const CUDA_LONG N) {
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kNumElementsPerThread];
int outer_block_index = 0;
int block_index = 0;
int offset = 0;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
int outer_block_index, block_index, offset, input_index, block_offset;
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
concat_dim_size.divmod(block_index, input_index, block_offset);
CUDA_LONG input_pos =
(outer_block_index * concat_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
value[i] = reinterpret_cast<const T*>(input_data[input_index])[input_pos];
id += kNumThreadsPerBlock;
}
}
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int input_index = 0;
int block_offset = 0;
concat_dim_size.divmod(block_index, input_index, block_offset);
input_pos = (outer_block_index * concat_dim_size.d_ + block_offset) *
block_size_inside_axis_dim_div.d_ +
offset;
output_data[id] = reinterpret_cast<const T*>(input_ptr[input_index])[input_pos];
id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
output_data[id] = value[i];
id += kNumThreadsPerBlock;
}
}
}
template <typename InputIndexToMemoryMap>
Status ConcatSameConcatDimImpl(cudaStream_t stream,
const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t concat_size,
void* output_data,
const InputIndexToMemoryMap input_ptr,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
template <typename InputDataArray>
Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data,
const InputDataArray input_data, const size_t output_size) {
CUDA_LONG N = static_cast<CUDA_LONG>(output_size);
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
fast_divmod concat_dim_size = fast_divmod(static_cast<int>(concat_size));
switch (element_bytes) {
case sizeof(int8_t):
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_dim_size,
reinterpret_cast<int8_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
case sizeof(int16_t):
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_dim_size,
reinterpret_cast<int16_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
case sizeof(int32_t):
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_dim_size,
reinterpret_cast<int32_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
case sizeof(int64_t):
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_dim_size,
reinterpret_cast<int64_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
#define CASE_ELEMENT_TYPE(type) \
case sizeof(type): { \
_ConcatKernelSameConcatDim<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_dim_size, \
reinterpret_cast<ToCudaType<type>::MappedType*>(output_data), input_data, N); \
} break
CASE_ELEMENT_TYPE(int8_t);
CASE_ELEMENT_TYPE(int16_t);
CASE_ELEMENT_TYPE(int32_t);
CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
}
@ -92,103 +82,77 @@ Status ConcatSameConcatDimImpl(cudaStream_t stream,
}
// input tensors addresses in device memory
template Status ConcatSameConcatDimImpl<const void**>(cudaStream_t stream,
const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t concat_size,
void* output_data,
const void** input_ptr,
const size_t N);
template Status ConcatSameConcatDimImpl<const void**>(cudaStream_t stream, const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t concat_size,
void* output_data, const void** input_data,
const size_t output_size);
// input tensor addresses passed by value
template Status ConcatSameConcatDimImpl<TArray<const void*,32>>(cudaStream_t stream,
const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t concat_size,
void* output_data,
TArray<const void*,32> input_ptr,
const size_t N);
template Status ConcatSameConcatDimImpl<TArray<const void*, 32>>(cudaStream_t stream, const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t concat_size, void* output_data,
TArray<const void*, 32> input_data,
const size_t output_size);
template <typename T>
__global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div,
const fast_divmod block_size_inside_axis_dim_div,
const int64_t* concat_sizes,
const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping,
T* output_data,
const void** input_ptr,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_pos = 0;
const fast_divmod block_size_inside_axis_dim_div, const int64_t* concat_sizes,
const int64_t* concat_sizes_range, const int64_t* axis_dimension_input_output_mapping,
T* output_data, const void** input_data, const CUDA_LONG N) {
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kNumElementsPerThread];
int outer_block_index = 0;
int block_index = 0;
int offset = 0;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
int outer_block_index, block_index, offset;
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int input_index = axis_dimension_input_output_mapping[block_index];
int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1];
int block_offset = block_index - static_cast<int>(range_left);
CUDA_LONG input_pos =
(outer_block_index * concat_sizes[input_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
value[i] = reinterpret_cast<const T*>(input_data[input_index])[input_pos];
id += kNumThreadsPerBlock;
}
}
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int input_index = axis_dimension_input_output_mapping[block_index];
int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1];
int block_offset = block_index - range_left;
input_pos = (outer_block_index * concat_sizes[input_index] + block_offset) *
block_size_inside_axis_dim_div.d_ +
offset;
output_data[id] = reinterpret_cast<const T*>(input_ptr[input_index])[input_pos];
id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
output_data[id] = value[i];
id += kNumThreadsPerBlock;
}
}
}
Status ConcatImpl(cudaStream_t stream,
const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t* concat_sizes,
const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping,
void* output_data,
const void** input_ptr,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data,
const size_t output_size) {
CUDA_LONG N = static_cast<CUDA_LONG>(output_size);
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
switch (element_bytes) {
case sizeof(int8_t):
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
reinterpret_cast<int8_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
case sizeof(int16_t):
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
reinterpret_cast<int16_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
case sizeof(int32_t):
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
reinterpret_cast<int32_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
case sizeof(int64_t):
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
reinterpret_cast<int64_t*>(output_data),
input_ptr,
(CUDA_LONG)N);
break;
#define CASE_ELEMENT_TYPE(type) \
case sizeof(type): { \
_ConcatKernel<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_sizes, concat_sizes_range, \
axis_dimension_input_output_mapping, reinterpret_cast<ToCudaType<type>::MappedType*>(output_data), input_data, \
N); \
} break;
CASE_ELEMENT_TYPE(int8_t);
CASE_ELEMENT_TYPE(int16_t);
CASE_ELEMENT_TYPE(int32_t);
CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
}

View file

@ -9,26 +9,15 @@
namespace onnxruntime {
namespace cuda {
template<typename InputIndexToMemoryMap>
Status ConcatSameConcatDimImpl(cudaStream_t stream,
const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t concat_size,
void* output_data,
const InputIndexToMemoryMap input_ptr,
const size_t N);
template <typename InputDataArray>
Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data,
const InputDataArray input_data, const size_t output_size);
Status ConcatImpl(cudaStream_t stream,
const size_t element_bytes,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t* concat_sizes,
const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping,
void* output_data,
const void** input_ptr,
const size_t N);
Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range,
const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data,
const size_t output_size);
} // namespace cuda
} // namespace onnxruntime

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/cuda/tensor/split.h"
#include "core/providers/cuda/tensor/split_impl.h"
#include "core/providers/cpu/tensor/utils.h"
@ -34,7 +35,7 @@ ONNX_OPERATOR_KERNEL_EX(Split,
Status Split::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input_tensor = ctx->Input<Tensor>(0);
ORT_ENFORCE(nullptr != input_tensor);
ORT_ENFORCE(input_tensor);
auto& input_shape = input_tensor->Shape();
auto num_outputs = ctx->OutputCount();
int64_t axis = HandleNegativeAxis(axis_, input_shape.NumDimensions());
@ -44,7 +45,7 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const {
std::vector<int64_t> split_sizes(num_outputs);
const Tensor* split_tensor = ctx->Input<Tensor>(1);
if (split_tensor != nullptr) {
if (split_tensor) {
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "An split tensor must be a vector tensor.");
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
const int64_t* data = split_tensor->template Data<int64_t>();
@ -83,35 +84,38 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const {
}
}
if (input_tensor->Shape().Size() > 0) {
ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu());
if (input_tensor->Shape().Size() <= 0) return Status::OK();
size_t element_size = input_tensor->DataType()->Size();
if (std::all_of(split_sizes.begin(), split_sizes.end(), [&](int64_t size) { return size == split_sizes[0]; })) {
if (num_outputs <= 32) {
TArray<void*, 32> output_ptr_array(num_outputs);
for (int i = 0; i < num_outputs; ++i) output_ptr_array[i] = output_ptr_span[i];
ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), element_size, block_size_including_axis_dim,
block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data,
output_ptr_array, static_cast<size_t>(input_shape.Size())));
} else {
ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu());
ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), element_size, block_size_including_axis_dim,
block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data,
output_ptr.GpuPtr(), static_cast<size_t>(input_shape.Size())));
}
} else {
ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu());
CudaAsyncBuffer<int64_t> split_sizes_gpu(this, split_sizes);
ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu());
std::vector<int64_t> split_sizes_range(split_sizes);
for (size_t i = 1; i < split_sizes_range.size(); ++i) {
split_sizes_range[i] += split_sizes_range[i - 1];
}
CudaAsyncBuffer<int64_t> split_sizes_range_gpu(this, split_sizes_range);
ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu());
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu());
size_t element_size = input_tensor->DataType()->Size();
ORT_RETURN_IF_ERROR(SplitImpl(Stream(),
element_size,
block_size_including_axis_dim,
block_size_inside_axis_dim,
split_sizes_gpu.GpuPtr(),
split_sizes_range_gpu.GpuPtr(),
axis_dimension_input_output_mapping_gpu.GpuPtr(),
num_outputs,
input_data,
output_ptr.GpuPtr(),
input_shape.Size()));
ORT_RETURN_IF_ERROR(SplitImpl(Stream(), element_size, block_size_including_axis_dim, block_size_inside_axis_dim,
split_sizes_gpu.GpuPtr(), split_sizes_range_gpu.GpuPtr(),
axis_dimension_input_output_mapping_gpu.GpuPtr(), num_outputs, input_data,
output_ptr.GpuPtr(), static_cast<size_t>(input_shape.Size())));
}
return Status::OK();

View file

@ -1,91 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/tensor/split_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "split_impl.h"
namespace onnxruntime {
namespace cuda {
template <typename T, typename OutputIndexToMemoryMap>
namespace {
#ifdef USE_ROCM
constexpr int kNumElementsPerThread = 2;
constexpr int kNumThreadsPerBlock = 512;
#else
constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread;
constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock;
#endif
} // namespace
template <typename T, typename OutputDataArray>
__global__ void _SplitKernelSameSplitDim(const fast_divmod block_size_including_axis_dim_div,
const fast_divmod block_size_inside_axis_dim_div,
const fast_divmod split_dim_size,
const int num_outputs,
const T* input_data,
OutputIndexToMemoryMap output_ptr,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG output_pos = 0;
const fast_divmod block_size_inside_axis_dim_div,
const fast_divmod split_dim_size, const int num_outputs, const T* input_data,
OutputDataArray output_data, const CUDA_LONG N) {
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kNumElementsPerThread];
int outer_block_index = 0;
int block_index = 0;
int offset = 0;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
value[i] = input_data[id];
id += kNumThreadsPerBlock;
}
}
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int output_index = 0;
int block_offset = 0;
split_dim_size.divmod(block_index, output_index, block_offset);
output_pos = (outer_block_index * split_dim_size.d_ + block_offset) *
block_size_inside_axis_dim_div.d_ +
offset;
reinterpret_cast<T*>(output_ptr[output_index])[output_pos] = input_data[id];
id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
int outer_block_index, block_index, offset, output_index, block_offset;
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
split_dim_size.divmod(block_index, output_index, block_offset);
CUDA_LONG output_pos =
(outer_block_index * split_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
reinterpret_cast<T*>(output_data[output_index])[output_pos] = value[i];
id += kNumThreadsPerBlock;
}
}
}
template <typename OutputIndexToMemoryMap>
Status SplitSameSplitDimImpl(cudaStream_t stream,
const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t split_size,
const int num_outputs,
const void* input_data,
OutputIndexToMemoryMap output_ptr,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
template <typename OutputDataArray>
Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t split_size, const int num_outputs,
const void* input_data, OutputDataArray output_data, const size_t input_size) {
CUDA_LONG N = static_cast<CUDA_LONG>(input_size);
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
fast_divmod split_size_div = fast_divmod((int)split_size);
fast_divmod split_size_div = fast_divmod(static_cast<int>(split_size));
switch (element_size) {
case sizeof(int8_t):
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_size_div, num_outputs,
reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
case sizeof(int16_t):
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_size_div, num_outputs,
reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
case sizeof(int32_t):
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_size_div, num_outputs,
reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
case sizeof(int64_t):
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_size_div, num_outputs,
reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
#define CASE_ELEMENT_TYPE(type) \
case sizeof(type): { \
_SplitKernelSameSplitDim<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, split_size_div, num_outputs, \
reinterpret_cast<const ToCudaType<type>::MappedType*>(input_data), output_data, N); \
} break
CASE_ELEMENT_TYPE(int8_t);
CASE_ELEMENT_TYPE(int16_t);
CASE_ELEMENT_TYPE(int32_t);
CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator");
}
@ -93,106 +81,75 @@ Status SplitSameSplitDimImpl(cudaStream_t stream,
return Status::OK();
}
template Status SplitSameSplitDimImpl<void**>(cudaStream_t stream,
const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t split_size,
const int num_outputs,
const void* input_data,
void** output_ptr,
const size_t N);
template Status SplitSameSplitDimImpl<void**>(cudaStream_t stream, const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t split_size,
const int num_outputs, const void* input_data, void** output_data,
const size_t input_size);
template Status SplitSameSplitDimImpl<TArray<void*, 32>>(cudaStream_t stream, const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t split_size,
const int num_outputs, const void* input_data,
TArray<void*, 32> output_data, const size_t input_size);
template Status SplitSameSplitDimImpl<TArray<void*,32>>(cudaStream_t stream,
const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t split_size,
const int num_outputs,
const void* input_data,
TArray<void*,32> output_ptr,
const size_t N);
template <typename T>
__global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div,
const fast_divmod block_size_inside_axis_dim_div,
const int64_t* split_sizes,
const int64_t* split_sizes_range,
const int64_t* axis_dimension_input_output_mapping,
const int num_outputs,
const T* input_data,
void** output_ptr,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG output_pos = 0;
const fast_divmod block_size_inside_axis_dim_div, const int64_t* split_sizes,
const int64_t* split_sizes_range, const int64_t* axis_dimension_input_output_mapping,
const int num_outputs, const T* input_data, void** output_data, const CUDA_LONG N) {
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kNumElementsPerThread];
int outer_block_index = 0;
int block_index = 0;
int offset = 0;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
value[i] = input_data[id];
id += kNumThreadsPerBlock;
}
}
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int output_index = axis_dimension_input_output_mapping[block_index];
int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1];
int block_offset = block_index - range_left;
output_pos = (outer_block_index * split_sizes[output_index] + block_offset) *
block_size_inside_axis_dim_div.d_ +
offset;
reinterpret_cast<T*>(output_ptr[output_index])[output_pos] = input_data[id];
id = start;
#pragma unroll
for (int i = 0; i < kNumElementsPerThread; ++i) {
if (id < N) {
int outer_block_index, block_index, offset;
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
int output_index = axis_dimension_input_output_mapping[block_index];
int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1];
int block_offset = block_index - static_cast<int>(range_left);
CUDA_LONG output_pos =
(outer_block_index * split_sizes[output_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
reinterpret_cast<T*>(output_data[output_index])[output_pos] = value[i];
id += kNumThreadsPerBlock;
}
}
}
Status SplitImpl(cudaStream_t stream,
const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t* split_sizes,
const int64_t* split_sizes_range,
const int64_t* axis_dimension_input_output_mapping,
const int num_outputs,
const void* input_data,
void** output_ptr,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t* split_sizes, const int64_t* split_sizes_range,
const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
void** output_data, const size_t input_size) {
CUDA_LONG N = static_cast<CUDA_LONG>(input_size);
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
switch (element_size) {
case sizeof(int8_t):
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
case sizeof(int16_t):
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
case sizeof(int32_t):
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
case sizeof(int64_t):
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
output_ptr,
(CUDA_LONG)N);
break;
#define CASE_ELEMENT_TYPE(type) \
case sizeof(type): { \
_SplitKernel<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, split_sizes, split_sizes_range, \
axis_dimension_input_output_mapping, num_outputs, \
reinterpret_cast<const ToCudaType<type>::MappedType*>(input_data), output_data, N); \
} break
CASE_ELEMENT_TYPE(int8_t);
CASE_ELEMENT_TYPE(int16_t);
CASE_ELEMENT_TYPE(int32_t);
CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator");
}

View file

@ -9,29 +9,15 @@
namespace onnxruntime {
namespace cuda {
template<typename OutputIndexToMemoryMap>
Status SplitSameSplitDimImpl(cudaStream_t stream,
const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t split_size,
const int num_outputs,
const void* input_data,
OutputIndexToMemoryMap output_ptr,
const size_t N);
template <typename OutputDataArray>
Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t split_size, const int num_outputs,
const void* input_data, OutputDataArray output_data, const size_t input_size);
Status SplitImpl(cudaStream_t stream,
const size_t element_size,
const int block_size_including_axis_dim,
const int block_size_inside_axis_dim,
const int64_t* split_sizes,
const int64_t* split_sizes_range,
const int64_t* axis_dimension_input_output_mapping,
const int num_outputs,
const void* input_data,
void** output_ptr,
const size_t N);
Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
const int block_size_inside_axis_dim, const int64_t* split_sizes, const int64_t* split_sizes_range,
const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
void** output_data, const size_t input_size);
} // namespace cuda
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
import torch
import torch.onnx.symbolic_helper as sym_help
from packaging.version import Version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
@ -20,7 +21,9 @@ class CustomOpSymbolicRegistry:
def register_all(cls):
for name, fn in cls._SYMBOLICS.items():
# Symbolic name is in format: domain::name
register_custom_op_symbolic(name, fn, 1)
# Exporter will fail to register symbolic with non-empty domain when torch version is < 1.11.0.
if Version(torch.__version__) >= Version("1.11.0") or name.startswith("::"):
register_custom_op_symbolic(name, fn, 1)
def register_symbolic(name, domain=""):
@ -220,6 +223,34 @@ def squeeze(g, self, dim=None):
return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim])
# Exporter's prim::ConstantChunk uses multiple Slice nodes, which is fine for inference.
# For training, the gradient graph will be multiple SliceGrad and one Sum, which is inefficient compared to
# exporting to Split with SplitGrad as gradient graph.
@register_symbolic("ConstantChunk", "prim")
def prim_ConstantChunk(g, self, chunks, dim):
input_shape_dim = g.op(
"Gather", g.op("Shape", self), g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)), axis_i=0
)
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
chunk_dim = g.op(
"Div",
g.op("Add", input_shape_dim, chunk_size_minus_1),
g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)),
)
return g.op(
"Split",
self,
g.op(
"Concat",
g.op("Expand", chunk_dim, chunk_size_minus_1),
g.op("Sub", input_shape_dim, g.op("Mul", chunk_dim, chunk_size_minus_1)),
axis_i=0,
),
axis_i=dim,
outputs=chunks,
)
# For torch.einsum.
def parse_equation(equation):
pos_comma = equation.find(",")

View file

@ -11,14 +11,15 @@ import random
import tempfile
import warnings
from collections import OrderedDict, namedtuple
from distutils.version import LooseVersion
from inspect import signature
from time import sleep
from unittest.mock import patch
import _test_helpers
import onnx
import pytest
import torch
from packaging.version import Version
# Import autocasting libs
from torch.cuda import amp
@ -1095,7 +1096,7 @@ def test_export_correctness_pool2d(pool_type, stride):
torch.bfloat16,
marks=[
pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.10.0"),
Version(torch.__version__) < Version("1.10.0"),
reason="PyTorch 1.9 incompatible",
)
],
@ -1141,7 +1142,7 @@ def test_gradient_correctness_minmax(operator, dim, keepdim, data_type):
# Before 1.10 (excluded), Torch's min/max(x,y) will assign dY to y's dX if value from x and y are equal.
# From 1.10, both x and y's dX will be dY/2. ORT follows this distribution logic, so skip below test if Torch version
# is lower than 1.10.
@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.10.0"), reason="PyTorch 1.9 incompatible")
@pytest.mark.skipif(Version(torch.__version__) < Version("1.10.0"), reason="PyTorch 1.9 incompatible")
@pytest.mark.parametrize("operator", ["min", "max"])
def test_gradient_correctness_minmax_two_tensors(operator):
func = getattr(torch, operator)
@ -1322,12 +1323,62 @@ def test_gradient_correctness_reducesum(dim, keepdim):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
# Before PyTorch 1.11.0, the exporter will fail to register symbolic with non-empty domain.
@pytest.mark.skipif(Version(torch.__version__) < Version("1.11.0"), reason="PyTorch 1.10 incompatible")
@pytest.mark.parametrize("dim", [0, 1, -1])
def test_gradient_correctness_chunk(dim):
class NeuralNetChunk(torch.nn.Module):
def __init__(self, dim):
super(NeuralNetChunk, self).__init__()
self.dim = dim
def forward(self, input):
return input.chunk(3, dim=self.dim)
device = "cuda"
pt_model = NeuralNetChunk(dim).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="chunk_model"))
def run_step(model, input):
y1, y2, y3 = model(input)
loss = y1.sum() + y2.sum() + y3.sum()
loss.backward()
return y1, y2, y3
N, D, H = 16, 17, 18
for _ in range(10):
input = torch.rand((N, D, H), device=device, requires_grad=True)
pt_y1, pt_y2, pt_y3 = run_step(pt_model, input)
ort_y1, ort_y2, ort_y3 = run_step(ort_model, input)
_test_helpers.assert_values_are_close(ort_y1, pt_y1)
_test_helpers.assert_values_are_close(ort_y2, pt_y2)
_test_helpers.assert_values_are_close(ort_y3, pt_y3)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
has_split = False
for node in model.graph.node:
if node.op_type == "Split":
has_split = True
break
assert has_split
os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
# In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that
# contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm".
# Currently skip these cases and test_gradient_correctness_einsum_2,
# will enable these tests again once the issue in PyTorch is fixed.
skip_torch_1_11 = pytest.mark.skipif(
LooseVersion(torch.__version__) >= LooseVersion("1.11.0"), reason="PyTorch 1.11 incompatible"
Version(torch.__version__) >= Version("1.11.0"), reason="PyTorch 1.11 incompatible"
)
@ -4688,7 +4739,7 @@ def test_ortmodule_ortmodule_method_attribute_copy():
assert type(out2.grad_fn).__name__ == "_ORTModuleFunctionBackward"
assert (
type(out3.grad_fn).__name__ == "AddmmBackward0"
if LooseVersion(torch.__version__) >= LooseVersion("1.10.0")
if Version(torch.__version__) >= Version("1.10.0")
else "AddmmBackward"
)