mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
[CUDA] Split/Concat Kernel Optimization (#12175)
* split concat optimization * bugfix * fix ut * deprecate LooseVersion
This commit is contained in:
parent
ced7c2deac
commit
173bcdbc71
8 changed files with 394 additions and 401 deletions
|
|
@ -1,8 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "concat.h"
|
||||
#include "concat_impl.h"
|
||||
#include "core/providers/cuda/tensor/concat.h"
|
||||
|
||||
#include "core/providers/cuda/tensor/concat_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
@ -63,31 +64,41 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const {
|
|||
axis_dimension_input_output_mapping.at(index++) = i;
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> concat_sizes_range(concat_sizes);
|
||||
for (size_t i = 1; i < concat_sizes_range.size(); ++i) {
|
||||
concat_sizes_range[i] += concat_sizes_range[i - 1];
|
||||
}
|
||||
|
||||
CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, concat_sizes);
|
||||
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
|
||||
CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, concat_sizes_range);
|
||||
ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu());
|
||||
auto element_bytes = p.output_tensor->DataType()->Size();
|
||||
int block_size_inside_axis_dim = static_cast<int>(p.output_axis_pitch / p.output_tensor->Shape()[p.axis]);
|
||||
int block_size_including_axis_dim = static_cast<int>(p.output_axis_pitch);
|
||||
auto element_bytes = p.output_tensor->DataType()->Size();
|
||||
ORT_RETURN_IF_ERROR(ConcatImpl(Stream(),
|
||||
element_bytes,
|
||||
block_size_including_axis_dim,
|
||||
block_size_inside_axis_dim,
|
||||
concat_sizes_gpu.GpuPtr(),
|
||||
concat_sizes_range_gpu.GpuPtr(),
|
||||
axis_dimension_input_output_mapping_gpu.GpuPtr(),
|
||||
p.output_tensor->MutableDataRaw(),
|
||||
input_ptr.GpuPtr(),
|
||||
p.output_num_elements));
|
||||
if (std::all_of(concat_sizes.begin(), concat_sizes.end(), [&](int64_t size) { return size == concat_sizes[0]; })) {
|
||||
if (input_count <= 32) {
|
||||
TArray<const void*, 32> input_ptr_array(input_count);
|
||||
for (int i = 0; i < input_count; ++i) input_ptr_array[i] = input_ptr_cpuspan[i];
|
||||
ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(
|
||||
Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0],
|
||||
p.output_tensor->MutableDataRaw(), input_ptr_array, static_cast<size_t>(p.output_num_elements)));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(
|
||||
Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0],
|
||||
p.output_tensor->MutableDataRaw(), input_ptr.GpuPtr(), static_cast<size_t>(p.output_num_elements)));
|
||||
}
|
||||
} else {
|
||||
CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, concat_sizes);
|
||||
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
|
||||
std::vector<int64_t> concat_sizes_range(concat_sizes);
|
||||
for (size_t i = 1; i < concat_sizes_range.size(); ++i) {
|
||||
concat_sizes_range[i] += concat_sizes_range[i - 1];
|
||||
}
|
||||
CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, concat_sizes_range);
|
||||
ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(ConcatImpl(Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim,
|
||||
concat_sizes_gpu.GpuPtr(), concat_sizes_range_gpu.GpuPtr(),
|
||||
axis_dimension_input_output_mapping_gpu.GpuPtr(), p.output_tensor->MutableDataRaw(),
|
||||
input_ptr.GpuPtr(), static_cast<size_t>(p.output_num_elements)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -1,89 +1,79 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/concat_impl.h"
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "concat_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
namespace {
|
||||
#ifdef USE_ROCM
|
||||
constexpr int kNumElementsPerThread = 2;
|
||||
constexpr int kNumThreadsPerBlock = 512;
|
||||
#else
|
||||
constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread;
|
||||
constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock;
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
// concat dimension are same for all inputs
|
||||
template <typename T, typename InputIndexToMemoryMap>
|
||||
template <typename T, typename InputDataArray>
|
||||
__global__ void _ConcatKernelSameConcatDim(const fast_divmod block_size_including_axis_dim_div,
|
||||
const fast_divmod block_size_inside_axis_dim_div,
|
||||
const fast_divmod concat_dim_size,
|
||||
T* output_data,
|
||||
InputIndexToMemoryMap input_ptr,
|
||||
const CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
CUDA_LONG input_pos = 0;
|
||||
const fast_divmod block_size_inside_axis_dim_div,
|
||||
const fast_divmod concat_dim_size, T* output_data, InputDataArray input_data,
|
||||
const CUDA_LONG N) {
|
||||
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
T value[kNumElementsPerThread];
|
||||
|
||||
int outer_block_index = 0;
|
||||
int block_index = 0;
|
||||
int offset = 0;
|
||||
CUDA_LONG id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
int outer_block_index, block_index, offset, input_index, block_offset;
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
concat_dim_size.divmod(block_index, input_index, block_offset);
|
||||
CUDA_LONG input_pos =
|
||||
(outer_block_index * concat_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
|
||||
value[i] = reinterpret_cast<const T*>(input_data[input_index])[input_pos];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
|
||||
int input_index = 0;
|
||||
int block_offset = 0;
|
||||
concat_dim_size.divmod(block_index, input_index, block_offset);
|
||||
|
||||
input_pos = (outer_block_index * concat_dim_size.d_ + block_offset) *
|
||||
block_size_inside_axis_dim_div.d_ +
|
||||
offset;
|
||||
|
||||
output_data[id] = reinterpret_cast<const T*>(input_ptr[input_index])[input_pos];
|
||||
id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
output_data[id] = value[i];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InputIndexToMemoryMap>
|
||||
Status ConcatSameConcatDimImpl(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t concat_size,
|
||||
void* output_data,
|
||||
const InputIndexToMemoryMap input_ptr,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
|
||||
template <typename InputDataArray>
|
||||
Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data,
|
||||
const InputDataArray input_data, const size_t output_size) {
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(output_size);
|
||||
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
|
||||
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
|
||||
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
|
||||
fast_divmod concat_dim_size = fast_divmod(static_cast<int>(concat_size));
|
||||
switch (element_bytes) {
|
||||
case sizeof(int8_t):
|
||||
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_dim_size,
|
||||
reinterpret_cast<int8_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int16_t):
|
||||
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_dim_size,
|
||||
reinterpret_cast<int16_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int32_t):
|
||||
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_dim_size,
|
||||
reinterpret_cast<int32_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int64_t):
|
||||
_ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_dim_size,
|
||||
reinterpret_cast<int64_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
#define CASE_ELEMENT_TYPE(type) \
|
||||
case sizeof(type): { \
|
||||
_ConcatKernelSameConcatDim<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_dim_size, \
|
||||
reinterpret_cast<ToCudaType<type>::MappedType*>(output_data), input_data, N); \
|
||||
} break
|
||||
CASE_ELEMENT_TYPE(int8_t);
|
||||
CASE_ELEMENT_TYPE(int16_t);
|
||||
CASE_ELEMENT_TYPE(int32_t);
|
||||
CASE_ELEMENT_TYPE(int64_t);
|
||||
#undef CASE_ELEMENT_TYPE
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
|
||||
}
|
||||
|
|
@ -92,103 +82,77 @@ Status ConcatSameConcatDimImpl(cudaStream_t stream,
|
|||
}
|
||||
|
||||
// input tensors addresses in device memory
|
||||
template Status ConcatSameConcatDimImpl<const void**>(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t concat_size,
|
||||
void* output_data,
|
||||
const void** input_ptr,
|
||||
const size_t N);
|
||||
template Status ConcatSameConcatDimImpl<const void**>(cudaStream_t stream, const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t concat_size,
|
||||
void* output_data, const void** input_data,
|
||||
const size_t output_size);
|
||||
|
||||
// input tensor addresses passed by value
|
||||
template Status ConcatSameConcatDimImpl<TArray<const void*,32>>(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t concat_size,
|
||||
void* output_data,
|
||||
TArray<const void*,32> input_ptr,
|
||||
const size_t N);
|
||||
template Status ConcatSameConcatDimImpl<TArray<const void*, 32>>(cudaStream_t stream, const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t concat_size, void* output_data,
|
||||
TArray<const void*, 32> input_data,
|
||||
const size_t output_size);
|
||||
|
||||
template <typename T>
|
||||
__global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div,
|
||||
const fast_divmod block_size_inside_axis_dim_div,
|
||||
const int64_t* concat_sizes,
|
||||
const int64_t* concat_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping,
|
||||
T* output_data,
|
||||
const void** input_ptr,
|
||||
const CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
CUDA_LONG input_pos = 0;
|
||||
const fast_divmod block_size_inside_axis_dim_div, const int64_t* concat_sizes,
|
||||
const int64_t* concat_sizes_range, const int64_t* axis_dimension_input_output_mapping,
|
||||
T* output_data, const void** input_data, const CUDA_LONG N) {
|
||||
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
T value[kNumElementsPerThread];
|
||||
|
||||
int outer_block_index = 0;
|
||||
int block_index = 0;
|
||||
int offset = 0;
|
||||
CUDA_LONG id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
int outer_block_index, block_index, offset;
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
int input_index = axis_dimension_input_output_mapping[block_index];
|
||||
int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1];
|
||||
int block_offset = block_index - static_cast<int>(range_left);
|
||||
CUDA_LONG input_pos =
|
||||
(outer_block_index * concat_sizes[input_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
|
||||
value[i] = reinterpret_cast<const T*>(input_data[input_index])[input_pos];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
|
||||
int input_index = axis_dimension_input_output_mapping[block_index];
|
||||
int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1];
|
||||
int block_offset = block_index - range_left;
|
||||
|
||||
input_pos = (outer_block_index * concat_sizes[input_index] + block_offset) *
|
||||
block_size_inside_axis_dim_div.d_ +
|
||||
offset;
|
||||
|
||||
output_data[id] = reinterpret_cast<const T*>(input_ptr[input_index])[input_pos];
|
||||
id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
output_data[id] = value[i];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status ConcatImpl(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t* concat_sizes,
|
||||
const int64_t* concat_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping,
|
||||
void* output_data,
|
||||
const void** input_ptr,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
|
||||
Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data,
|
||||
const size_t output_size) {
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(output_size);
|
||||
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
|
||||
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
|
||||
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
|
||||
|
||||
switch (element_bytes) {
|
||||
case sizeof(int8_t):
|
||||
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
|
||||
reinterpret_cast<int8_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int16_t):
|
||||
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
|
||||
reinterpret_cast<int16_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int32_t):
|
||||
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
|
||||
reinterpret_cast<int32_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int64_t):
|
||||
_ConcatKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping,
|
||||
reinterpret_cast<int64_t*>(output_data),
|
||||
input_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
#define CASE_ELEMENT_TYPE(type) \
|
||||
case sizeof(type): { \
|
||||
_ConcatKernel<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_sizes, concat_sizes_range, \
|
||||
axis_dimension_input_output_mapping, reinterpret_cast<ToCudaType<type>::MappedType*>(output_data), input_data, \
|
||||
N); \
|
||||
} break;
|
||||
CASE_ELEMENT_TYPE(int8_t);
|
||||
CASE_ELEMENT_TYPE(int16_t);
|
||||
CASE_ELEMENT_TYPE(int32_t);
|
||||
CASE_ELEMENT_TYPE(int64_t);
|
||||
#undef CASE_ELEMENT_TYPE
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,26 +9,15 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename InputIndexToMemoryMap>
|
||||
Status ConcatSameConcatDimImpl(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t concat_size,
|
||||
void* output_data,
|
||||
const InputIndexToMemoryMap input_ptr,
|
||||
const size_t N);
|
||||
template <typename InputDataArray>
|
||||
Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data,
|
||||
const InputDataArray input_data, const size_t output_size);
|
||||
|
||||
Status ConcatImpl(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t* concat_sizes,
|
||||
const int64_t* concat_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping,
|
||||
void* output_data,
|
||||
const void** input_ptr,
|
||||
const size_t N);
|
||||
Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data,
|
||||
const size_t output_size);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/split.h"
|
||||
|
||||
#include "core/providers/cuda/tensor/split_impl.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
|
|
@ -34,7 +35,7 @@ ONNX_OPERATOR_KERNEL_EX(Split,
|
|||
|
||||
Status Split::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* input_tensor = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(nullptr != input_tensor);
|
||||
ORT_ENFORCE(input_tensor);
|
||||
auto& input_shape = input_tensor->Shape();
|
||||
auto num_outputs = ctx->OutputCount();
|
||||
int64_t axis = HandleNegativeAxis(axis_, input_shape.NumDimensions());
|
||||
|
|
@ -44,7 +45,7 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> split_sizes(num_outputs);
|
||||
|
||||
const Tensor* split_tensor = ctx->Input<Tensor>(1);
|
||||
if (split_tensor != nullptr) {
|
||||
if (split_tensor) {
|
||||
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "An split tensor must be a vector tensor.");
|
||||
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
|
||||
const int64_t* data = split_tensor->template Data<int64_t>();
|
||||
|
|
@ -83,35 +84,38 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
if (input_tensor->Shape().Size() > 0) {
|
||||
ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu());
|
||||
if (input_tensor->Shape().Size() <= 0) return Status::OK();
|
||||
|
||||
size_t element_size = input_tensor->DataType()->Size();
|
||||
if (std::all_of(split_sizes.begin(), split_sizes.end(), [&](int64_t size) { return size == split_sizes[0]; })) {
|
||||
if (num_outputs <= 32) {
|
||||
TArray<void*, 32> output_ptr_array(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) output_ptr_array[i] = output_ptr_span[i];
|
||||
ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), element_size, block_size_including_axis_dim,
|
||||
block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data,
|
||||
output_ptr_array, static_cast<size_t>(input_shape.Size())));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), element_size, block_size_including_axis_dim,
|
||||
block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data,
|
||||
output_ptr.GpuPtr(), static_cast<size_t>(input_shape.Size())));
|
||||
}
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu());
|
||||
CudaAsyncBuffer<int64_t> split_sizes_gpu(this, split_sizes);
|
||||
ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu());
|
||||
|
||||
std::vector<int64_t> split_sizes_range(split_sizes);
|
||||
for (size_t i = 1; i < split_sizes_range.size(); ++i) {
|
||||
split_sizes_range[i] += split_sizes_range[i - 1];
|
||||
}
|
||||
|
||||
CudaAsyncBuffer<int64_t> split_sizes_range_gpu(this, split_sizes_range);
|
||||
ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu());
|
||||
|
||||
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
|
||||
ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu());
|
||||
|
||||
size_t element_size = input_tensor->DataType()->Size();
|
||||
ORT_RETURN_IF_ERROR(SplitImpl(Stream(),
|
||||
element_size,
|
||||
block_size_including_axis_dim,
|
||||
block_size_inside_axis_dim,
|
||||
split_sizes_gpu.GpuPtr(),
|
||||
split_sizes_range_gpu.GpuPtr(),
|
||||
axis_dimension_input_output_mapping_gpu.GpuPtr(),
|
||||
num_outputs,
|
||||
input_data,
|
||||
output_ptr.GpuPtr(),
|
||||
input_shape.Size()));
|
||||
ORT_RETURN_IF_ERROR(SplitImpl(Stream(), element_size, block_size_including_axis_dim, block_size_inside_axis_dim,
|
||||
split_sizes_gpu.GpuPtr(), split_sizes_range_gpu.GpuPtr(),
|
||||
axis_dimension_input_output_mapping_gpu.GpuPtr(), num_outputs, input_data,
|
||||
output_ptr.GpuPtr(), static_cast<size_t>(input_shape.Size())));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -1,91 +1,79 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/tensor/split_impl.h"
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "split_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T, typename OutputIndexToMemoryMap>
|
||||
namespace {
|
||||
#ifdef USE_ROCM
|
||||
constexpr int kNumElementsPerThread = 2;
|
||||
constexpr int kNumThreadsPerBlock = 512;
|
||||
#else
|
||||
constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread;
|
||||
constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock;
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename OutputDataArray>
|
||||
__global__ void _SplitKernelSameSplitDim(const fast_divmod block_size_including_axis_dim_div,
|
||||
const fast_divmod block_size_inside_axis_dim_div,
|
||||
const fast_divmod split_dim_size,
|
||||
const int num_outputs,
|
||||
const T* input_data,
|
||||
OutputIndexToMemoryMap output_ptr,
|
||||
const CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
CUDA_LONG output_pos = 0;
|
||||
const fast_divmod block_size_inside_axis_dim_div,
|
||||
const fast_divmod split_dim_size, const int num_outputs, const T* input_data,
|
||||
OutputDataArray output_data, const CUDA_LONG N) {
|
||||
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
T value[kNumElementsPerThread];
|
||||
|
||||
int outer_block_index = 0;
|
||||
int block_index = 0;
|
||||
int offset = 0;
|
||||
CUDA_LONG id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
value[i] = input_data[id];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
|
||||
int output_index = 0;
|
||||
int block_offset = 0;
|
||||
split_dim_size.divmod(block_index, output_index, block_offset);
|
||||
|
||||
output_pos = (outer_block_index * split_dim_size.d_ + block_offset) *
|
||||
block_size_inside_axis_dim_div.d_ +
|
||||
offset;
|
||||
|
||||
reinterpret_cast<T*>(output_ptr[output_index])[output_pos] = input_data[id];
|
||||
id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
int outer_block_index, block_index, offset, output_index, block_offset;
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
split_dim_size.divmod(block_index, output_index, block_offset);
|
||||
CUDA_LONG output_pos =
|
||||
(outer_block_index * split_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
|
||||
reinterpret_cast<T*>(output_data[output_index])[output_pos] = value[i];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutputIndexToMemoryMap>
|
||||
Status SplitSameSplitDimImpl(cudaStream_t stream,
|
||||
const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t split_size,
|
||||
const int num_outputs,
|
||||
const void* input_data,
|
||||
OutputIndexToMemoryMap output_ptr,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
|
||||
template <typename OutputDataArray>
|
||||
Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t split_size, const int num_outputs,
|
||||
const void* input_data, OutputDataArray output_data, const size_t input_size) {
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(input_size);
|
||||
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
|
||||
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
|
||||
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
|
||||
fast_divmod split_size_div = fast_divmod((int)split_size);
|
||||
fast_divmod split_size_div = fast_divmod(static_cast<int>(split_size));
|
||||
|
||||
switch (element_size) {
|
||||
case sizeof(int8_t):
|
||||
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_size_div, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int16_t):
|
||||
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_size_div, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int32_t):
|
||||
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_size_div, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int64_t):
|
||||
_SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_size_div, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
#define CASE_ELEMENT_TYPE(type) \
|
||||
case sizeof(type): { \
|
||||
_SplitKernelSameSplitDim<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, split_size_div, num_outputs, \
|
||||
reinterpret_cast<const ToCudaType<type>::MappedType*>(input_data), output_data, N); \
|
||||
} break
|
||||
CASE_ELEMENT_TYPE(int8_t);
|
||||
CASE_ELEMENT_TYPE(int16_t);
|
||||
CASE_ELEMENT_TYPE(int32_t);
|
||||
CASE_ELEMENT_TYPE(int64_t);
|
||||
#undef CASE_ELEMENT_TYPE
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator");
|
||||
}
|
||||
|
|
@ -93,106 +81,75 @@ Status SplitSameSplitDimImpl(cudaStream_t stream,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status SplitSameSplitDimImpl<void**>(cudaStream_t stream,
|
||||
const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t split_size,
|
||||
const int num_outputs,
|
||||
const void* input_data,
|
||||
void** output_ptr,
|
||||
const size_t N);
|
||||
template Status SplitSameSplitDimImpl<void**>(cudaStream_t stream, const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t split_size,
|
||||
const int num_outputs, const void* input_data, void** output_data,
|
||||
const size_t input_size);
|
||||
|
||||
template Status SplitSameSplitDimImpl<TArray<void*, 32>>(cudaStream_t stream, const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t split_size,
|
||||
const int num_outputs, const void* input_data,
|
||||
TArray<void*, 32> output_data, const size_t input_size);
|
||||
|
||||
template Status SplitSameSplitDimImpl<TArray<void*,32>>(cudaStream_t stream,
|
||||
const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t split_size,
|
||||
const int num_outputs,
|
||||
const void* input_data,
|
||||
TArray<void*,32> output_ptr,
|
||||
const size_t N);
|
||||
|
||||
template <typename T>
|
||||
__global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div,
|
||||
const fast_divmod block_size_inside_axis_dim_div,
|
||||
const int64_t* split_sizes,
|
||||
const int64_t* split_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping,
|
||||
const int num_outputs,
|
||||
const T* input_data,
|
||||
void** output_ptr,
|
||||
const CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
CUDA_LONG output_pos = 0;
|
||||
const fast_divmod block_size_inside_axis_dim_div, const int64_t* split_sizes,
|
||||
const int64_t* split_sizes_range, const int64_t* axis_dimension_input_output_mapping,
|
||||
const int num_outputs, const T* input_data, void** output_data, const CUDA_LONG N) {
|
||||
CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
T value[kNumElementsPerThread];
|
||||
|
||||
int outer_block_index = 0;
|
||||
int block_index = 0;
|
||||
int offset = 0;
|
||||
CUDA_LONG id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
value[i] = input_data[id];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
|
||||
int output_index = axis_dimension_input_output_mapping[block_index];
|
||||
int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1];
|
||||
int block_offset = block_index - range_left;
|
||||
|
||||
output_pos = (outer_block_index * split_sizes[output_index] + block_offset) *
|
||||
block_size_inside_axis_dim_div.d_ +
|
||||
offset;
|
||||
|
||||
reinterpret_cast<T*>(output_ptr[output_index])[output_pos] = input_data[id];
|
||||
id = start;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumElementsPerThread; ++i) {
|
||||
if (id < N) {
|
||||
int outer_block_index, block_index, offset;
|
||||
block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
|
||||
block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
|
||||
int output_index = axis_dimension_input_output_mapping[block_index];
|
||||
int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1];
|
||||
int block_offset = block_index - static_cast<int>(range_left);
|
||||
CUDA_LONG output_pos =
|
||||
(outer_block_index * split_sizes[output_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset;
|
||||
reinterpret_cast<T*>(output_data[output_index])[output_pos] = value[i];
|
||||
id += kNumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status SplitImpl(cudaStream_t stream,
|
||||
const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t* split_sizes,
|
||||
const int64_t* split_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping,
|
||||
const int num_outputs,
|
||||
const void* input_data,
|
||||
void** output_ptr,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
|
||||
Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t* split_sizes, const int64_t* split_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
|
||||
void** output_data, const size_t input_size) {
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(input_size);
|
||||
int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock);
|
||||
fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
|
||||
fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
|
||||
|
||||
switch (element_size) {
|
||||
case sizeof(int8_t):
|
||||
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int16_t):
|
||||
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int32_t):
|
||||
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
case sizeof(int64_t):
|
||||
_SplitKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
|
||||
split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs,
|
||||
reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
|
||||
output_ptr,
|
||||
(CUDA_LONG)N);
|
||||
break;
|
||||
#define CASE_ELEMENT_TYPE(type) \
|
||||
case sizeof(type): { \
|
||||
_SplitKernel<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>( \
|
||||
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, split_sizes, split_sizes_range, \
|
||||
axis_dimension_input_output_mapping, num_outputs, \
|
||||
reinterpret_cast<const ToCudaType<type>::MappedType*>(input_data), output_data, N); \
|
||||
} break
|
||||
CASE_ELEMENT_TYPE(int8_t);
|
||||
CASE_ELEMENT_TYPE(int16_t);
|
||||
CASE_ELEMENT_TYPE(int32_t);
|
||||
CASE_ELEMENT_TYPE(int64_t);
|
||||
#undef CASE_ELEMENT_TYPE
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,29 +9,15 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename OutputIndexToMemoryMap>
|
||||
Status SplitSameSplitDimImpl(cudaStream_t stream,
|
||||
const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t split_size,
|
||||
const int num_outputs,
|
||||
const void* input_data,
|
||||
OutputIndexToMemoryMap output_ptr,
|
||||
const size_t N);
|
||||
template <typename OutputDataArray>
|
||||
Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t split_size, const int num_outputs,
|
||||
const void* input_data, OutputDataArray output_data, const size_t input_size);
|
||||
|
||||
|
||||
Status SplitImpl(cudaStream_t stream,
|
||||
const size_t element_size,
|
||||
const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim,
|
||||
const int64_t* split_sizes,
|
||||
const int64_t* split_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping,
|
||||
const int num_outputs,
|
||||
const void* input_data,
|
||||
void** output_ptr,
|
||||
const size_t N);
|
||||
Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim,
|
||||
const int block_size_inside_axis_dim, const int64_t* split_sizes, const int64_t* split_sizes_range,
|
||||
const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
|
||||
void** output_data, const size_t input_size);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import torch
|
||||
import torch.onnx.symbolic_helper as sym_help
|
||||
from packaging.version import Version
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
|
||||
|
||||
|
|
@ -20,7 +21,9 @@ class CustomOpSymbolicRegistry:
|
|||
def register_all(cls):
|
||||
for name, fn in cls._SYMBOLICS.items():
|
||||
# Symbolic name is in format: domain::name
|
||||
register_custom_op_symbolic(name, fn, 1)
|
||||
# Exporter will fail to register symbolic with non-empty domain when torch version is < 1.11.0.
|
||||
if Version(torch.__version__) >= Version("1.11.0") or name.startswith("::"):
|
||||
register_custom_op_symbolic(name, fn, 1)
|
||||
|
||||
|
||||
def register_symbolic(name, domain=""):
|
||||
|
|
@ -220,6 +223,34 @@ def squeeze(g, self, dim=None):
|
|||
return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim])
|
||||
|
||||
|
||||
# Exporter's prim::ConstantChunk uses multiple Slice nodes, which is fine for inference.
|
||||
# For training, the gradient graph will be multiple SliceGrad and one Sum, which is inefficient compared to
|
||||
# exporting to Split with SplitGrad as gradient graph.
|
||||
@register_symbolic("ConstantChunk", "prim")
|
||||
def prim_ConstantChunk(g, self, chunks, dim):
|
||||
input_shape_dim = g.op(
|
||||
"Gather", g.op("Shape", self), g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)), axis_i=0
|
||||
)
|
||||
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
|
||||
chunk_dim = g.op(
|
||||
"Div",
|
||||
g.op("Add", input_shape_dim, chunk_size_minus_1),
|
||||
g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)),
|
||||
)
|
||||
return g.op(
|
||||
"Split",
|
||||
self,
|
||||
g.op(
|
||||
"Concat",
|
||||
g.op("Expand", chunk_dim, chunk_size_minus_1),
|
||||
g.op("Sub", input_shape_dim, g.op("Mul", chunk_dim, chunk_size_minus_1)),
|
||||
axis_i=0,
|
||||
),
|
||||
axis_i=dim,
|
||||
outputs=chunks,
|
||||
)
|
||||
|
||||
|
||||
# For torch.einsum.
|
||||
def parse_equation(equation):
|
||||
pos_comma = equation.find(",")
|
||||
|
|
|
|||
|
|
@ -11,14 +11,15 @@ import random
|
|||
import tempfile
|
||||
import warnings
|
||||
from collections import OrderedDict, namedtuple
|
||||
from distutils.version import LooseVersion
|
||||
from inspect import signature
|
||||
from time import sleep
|
||||
from unittest.mock import patch
|
||||
|
||||
import _test_helpers
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
# Import autocasting libs
|
||||
from torch.cuda import amp
|
||||
|
|
@ -1095,7 +1096,7 @@ def test_export_correctness_pool2d(pool_type, stride):
|
|||
torch.bfloat16,
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
LooseVersion(torch.__version__) < LooseVersion("1.10.0"),
|
||||
Version(torch.__version__) < Version("1.10.0"),
|
||||
reason="PyTorch 1.9 incompatible",
|
||||
)
|
||||
],
|
||||
|
|
@ -1141,7 +1142,7 @@ def test_gradient_correctness_minmax(operator, dim, keepdim, data_type):
|
|||
# Before 1.10 (excluded), Torch's min/max(x,y) will assign dY to y's dX if value from x and y are equal.
|
||||
# From 1.10, both x and y's dX will be dY/2. ORT follows this distribution logic, so skip below test if Torch version
|
||||
# is lower than 1.10.
|
||||
@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.10.0"), reason="PyTorch 1.9 incompatible")
|
||||
@pytest.mark.skipif(Version(torch.__version__) < Version("1.10.0"), reason="PyTorch 1.9 incompatible")
|
||||
@pytest.mark.parametrize("operator", ["min", "max"])
|
||||
def test_gradient_correctness_minmax_two_tensors(operator):
|
||||
func = getattr(torch, operator)
|
||||
|
|
@ -1322,12 +1323,62 @@ def test_gradient_correctness_reducesum(dim, keepdim):
|
|||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
# Before PyTorch 1.11.0, the exporter will fail to register symbolic with non-empty domain.
|
||||
@pytest.mark.skipif(Version(torch.__version__) < Version("1.11.0"), reason="PyTorch 1.10 incompatible")
|
||||
@pytest.mark.parametrize("dim", [0, 1, -1])
|
||||
def test_gradient_correctness_chunk(dim):
|
||||
class NeuralNetChunk(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(NeuralNetChunk, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, input):
|
||||
return input.chunk(3, dim=self.dim)
|
||||
|
||||
device = "cuda"
|
||||
pt_model = NeuralNetChunk(dim).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="chunk_model"))
|
||||
|
||||
def run_step(model, input):
|
||||
y1, y2, y3 = model(input)
|
||||
loss = y1.sum() + y2.sum() + y3.sum()
|
||||
loss.backward()
|
||||
return y1, y2, y3
|
||||
|
||||
N, D, H = 16, 17, 18
|
||||
for _ in range(10):
|
||||
input = torch.rand((N, D, H), device=device, requires_grad=True)
|
||||
pt_y1, pt_y2, pt_y3 = run_step(pt_model, input)
|
||||
ort_y1, ort_y2, ort_y3 = run_step(ort_model, input)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_y1, pt_y1)
|
||||
_test_helpers.assert_values_are_close(ort_y2, pt_y2)
|
||||
_test_helpers.assert_values_are_close(ort_y3, pt_y3)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
|
||||
model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
has_split = False
|
||||
for node in model.graph.node:
|
||||
if node.op_type == "Split":
|
||||
has_split = True
|
||||
break
|
||||
assert has_split
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
|
||||
|
||||
|
||||
# In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that
|
||||
# contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm".
|
||||
# Currently skip these cases and test_gradient_correctness_einsum_2,
|
||||
# will enable these tests again once the issue in PyTorch is fixed.
|
||||
skip_torch_1_11 = pytest.mark.skipif(
|
||||
LooseVersion(torch.__version__) >= LooseVersion("1.11.0"), reason="PyTorch 1.11 incompatible"
|
||||
Version(torch.__version__) >= Version("1.11.0"), reason="PyTorch 1.11 incompatible"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -4688,7 +4739,7 @@ def test_ortmodule_ortmodule_method_attribute_copy():
|
|||
assert type(out2.grad_fn).__name__ == "_ORTModuleFunctionBackward"
|
||||
assert (
|
||||
type(out3.grad_fn).__name__ == "AddmmBackward0"
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.10.0")
|
||||
if Version(torch.__version__) >= Version("1.10.0")
|
||||
else "AddmmBackward"
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue