From 173bcdbc7189dd43f4e16029fb5ebcffee0c92bd Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 19 Jul 2022 08:10:46 +0800 Subject: [PATCH] [CUDA] Split/Concat Kernel Optimization (#12175) * split concat optimization * bugfix * fix ut * deprecate LooseVersion --- .../core/providers/cuda/tensor/concat.cc | 59 ++-- .../core/providers/cuda/tensor/concat_impl.cu | 266 ++++++++--------- .../core/providers/cuda/tensor/concat_impl.h | 27 +- .../core/providers/cuda/tensor/split.cc | 44 +-- .../core/providers/cuda/tensor/split_impl.cu | 275 ++++++++---------- .../core/providers/cuda/tensor/split_impl.h | 30 +- .../ortmodule/_custom_op_symbolic_registry.py | 33 ++- .../python/orttraining_test_ortmodule_api.py | 61 +++- 8 files changed, 394 insertions(+), 401 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/concat.cc b/onnxruntime/core/providers/cuda/tensor/concat.cc index 8654a0c347..c6b3b27a52 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat.cc +++ b/onnxruntime/core/providers/cuda/tensor/concat.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "concat.h" -#include "concat_impl.h" +#include "core/providers/cuda/tensor/concat.h" + +#include "core/providers/cuda/tensor/concat_impl.h" namespace onnxruntime { namespace cuda { @@ -63,31 +64,41 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { axis_dimension_input_output_mapping.at(index++) = i; } } - std::vector concat_sizes_range(concat_sizes); - for (size_t i = 1; i < concat_sizes_range.size(); ++i) { - concat_sizes_range[i] += concat_sizes_range[i - 1]; - } - CudaAsyncBuffer concat_sizes_gpu(this, concat_sizes); - CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); - CudaAsyncBuffer concat_sizes_range_gpu(this, concat_sizes_range); - ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu()); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu()); - ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu()); - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu()); + auto element_bytes = p.output_tensor->DataType()->Size(); int block_size_inside_axis_dim = static_cast(p.output_axis_pitch / p.output_tensor->Shape()[p.axis]); int block_size_including_axis_dim = static_cast(p.output_axis_pitch); - auto element_bytes = p.output_tensor->DataType()->Size(); - ORT_RETURN_IF_ERROR(ConcatImpl(Stream(), - element_bytes, - block_size_including_axis_dim, - block_size_inside_axis_dim, - concat_sizes_gpu.GpuPtr(), - concat_sizes_range_gpu.GpuPtr(), - axis_dimension_input_output_mapping_gpu.GpuPtr(), - p.output_tensor->MutableDataRaw(), - input_ptr.GpuPtr(), - p.output_num_elements)); + if (std::all_of(concat_sizes.begin(), concat_sizes.end(), [&](int64_t size) { return size == concat_sizes[0]; })) { + if (input_count <= 32) { + TArray input_ptr_array(input_count); + for (int i = 0; i < input_count; ++i) input_ptr_array[i] = input_ptr_cpuspan[i]; + ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl( + Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0], + p.output_tensor->MutableDataRaw(), input_ptr_array, static_cast(p.output_num_elements))); + } else { + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu()); + ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl( + Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0], + p.output_tensor->MutableDataRaw(), input_ptr.GpuPtr(), static_cast(p.output_num_elements))); + } + } else { + CudaAsyncBuffer concat_sizes_gpu(this, concat_sizes); + CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); + std::vector concat_sizes_range(concat_sizes); + for (size_t i = 1; i < concat_sizes_range.size(); ++i) { + concat_sizes_range[i] += concat_sizes_range[i - 1]; + } + CudaAsyncBuffer concat_sizes_range_gpu(this, concat_sizes_range); + ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu()); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu()); + ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu()); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu()); + ORT_RETURN_IF_ERROR(ConcatImpl(Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, + concat_sizes_gpu.GpuPtr(), concat_sizes_range_gpu.GpuPtr(), + axis_dimension_input_output_mapping_gpu.GpuPtr(), p.output_tensor->MutableDataRaw(), + input_ptr.GpuPtr(), static_cast(p.output_num_elements))); + } + return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/concat_impl.cu b/onnxruntime/core/providers/cuda/tensor/concat_impl.cu index 51c8444196..84e1e76fae 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/concat_impl.cu @@ -1,89 +1,79 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/cuda/tensor/concat_impl.h" + #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" -#include "concat_impl.h" namespace onnxruntime { namespace cuda { +namespace { +#ifdef USE_ROCM +constexpr int kNumElementsPerThread = 2; +constexpr int kNumThreadsPerBlock = 512; +#else +constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread; +constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock; +#endif +} // namespace + // concat dimension are same for all inputs -template +template __global__ void _ConcatKernelSameConcatDim(const fast_divmod block_size_including_axis_dim_div, - const fast_divmod block_size_inside_axis_dim_div, - const fast_divmod concat_dim_size, - T* output_data, - InputIndexToMemoryMap input_ptr, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - CUDA_LONG input_pos = 0; + const fast_divmod block_size_inside_axis_dim_div, + const fast_divmod concat_dim_size, T* output_data, InputDataArray input_data, + const CUDA_LONG N) { + CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x; + T value[kNumElementsPerThread]; - int outer_block_index = 0; - int block_index = 0; - int offset = 0; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + int outer_block_index, block_index, offset, input_index, block_offset; + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); + block_size_inside_axis_dim_div.divmod(offset, block_index, offset); + concat_dim_size.divmod(block_index, input_index, block_offset); + CUDA_LONG input_pos = + (outer_block_index * concat_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset; + value[i] = reinterpret_cast(input_data[input_index])[input_pos]; + id += kNumThreadsPerBlock; + } + } - block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); - block_size_inside_axis_dim_div.divmod(offset, block_index, offset); - - int input_index = 0; - int block_offset = 0; - concat_dim_size.divmod(block_index, input_index, block_offset); - - input_pos = (outer_block_index * concat_dim_size.d_ + block_offset) * - block_size_inside_axis_dim_div.d_ + - offset; - - output_data[id] = reinterpret_cast(input_ptr[input_index])[input_pos]; + id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + output_data[id] = value[i]; + id += kNumThreadsPerBlock; + } + } } -template -Status ConcatSameConcatDimImpl(cudaStream_t stream, - const size_t element_bytes, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t concat_size, - void* output_data, - const InputIndexToMemoryMap input_ptr, - const size_t N) { - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - +template +Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data, + const InputDataArray input_data, const size_t output_size) { + CUDA_LONG N = static_cast(output_size); + int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock); fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim); fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim); fast_divmod concat_dim_size = fast_divmod(static_cast(concat_size)); switch (element_bytes) { - case sizeof(int8_t): - _ConcatKernelSameConcatDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_dim_size, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; - case sizeof(int16_t): - _ConcatKernelSameConcatDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_dim_size, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; - case sizeof(int32_t): - _ConcatKernelSameConcatDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_dim_size, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; - case sizeof(int64_t): - _ConcatKernelSameConcatDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_dim_size, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; +#define CASE_ELEMENT_TYPE(type) \ + case sizeof(type): { \ + _ConcatKernelSameConcatDim<<>>( \ + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_dim_size, \ + reinterpret_cast::MappedType*>(output_data), input_data, N); \ + } break + CASE_ELEMENT_TYPE(int8_t); + CASE_ELEMENT_TYPE(int16_t); + CASE_ELEMENT_TYPE(int32_t); + CASE_ELEMENT_TYPE(int64_t); +#undef CASE_ELEMENT_TYPE default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator"); } @@ -92,103 +82,77 @@ Status ConcatSameConcatDimImpl(cudaStream_t stream, } // input tensors addresses in device memory -template Status ConcatSameConcatDimImpl(cudaStream_t stream, - const size_t element_bytes, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t concat_size, - void* output_data, - const void** input_ptr, - const size_t N); +template Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t concat_size, + void* output_data, const void** input_data, + const size_t output_size); // input tensor addresses passed by value -template Status ConcatSameConcatDimImpl>(cudaStream_t stream, - const size_t element_bytes, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t concat_size, - void* output_data, - TArray input_ptr, - const size_t N); +template Status ConcatSameConcatDimImpl>(cudaStream_t stream, const size_t element_bytes, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t concat_size, void* output_data, + TArray input_data, + const size_t output_size); template __global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div, - const fast_divmod block_size_inside_axis_dim_div, - const int64_t* concat_sizes, - const int64_t* concat_sizes_range, - const int64_t* axis_dimension_input_output_mapping, - T* output_data, - const void** input_ptr, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - CUDA_LONG input_pos = 0; + const fast_divmod block_size_inside_axis_dim_div, const int64_t* concat_sizes, + const int64_t* concat_sizes_range, const int64_t* axis_dimension_input_output_mapping, + T* output_data, const void** input_data, const CUDA_LONG N) { + CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x; + T value[kNumElementsPerThread]; - int outer_block_index = 0; - int block_index = 0; - int offset = 0; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + int outer_block_index, block_index, offset; + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); + block_size_inside_axis_dim_div.divmod(offset, block_index, offset); + int input_index = axis_dimension_input_output_mapping[block_index]; + int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1]; + int block_offset = block_index - static_cast(range_left); + CUDA_LONG input_pos = + (outer_block_index * concat_sizes[input_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset; + value[i] = reinterpret_cast(input_data[input_index])[input_pos]; + id += kNumThreadsPerBlock; + } + } - block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); - block_size_inside_axis_dim_div.divmod(offset, block_index, offset); - - int input_index = axis_dimension_input_output_mapping[block_index]; - int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1]; - int block_offset = block_index - range_left; - - input_pos = (outer_block_index * concat_sizes[input_index] + block_offset) * - block_size_inside_axis_dim_div.d_ + - offset; - - output_data[id] = reinterpret_cast(input_ptr[input_index])[input_pos]; + id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + output_data[id] = value[i]; + id += kNumThreadsPerBlock; + } + } } -Status ConcatImpl(cudaStream_t stream, - const size_t element_bytes, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t* concat_sizes, - const int64_t* concat_sizes_range, - const int64_t* axis_dimension_input_output_mapping, - void* output_data, - const void** input_ptr, - const size_t N) { - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - +Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range, + const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data, + const size_t output_size) { + CUDA_LONG N = static_cast(output_size); + int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock); fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim); fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim); switch (element_bytes) { - case sizeof(int8_t): - _ConcatKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; - case sizeof(int16_t): - _ConcatKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; - case sizeof(int32_t): - _ConcatKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; - case sizeof(int64_t): - _ConcatKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - concat_sizes, concat_sizes_range, axis_dimension_input_output_mapping, - reinterpret_cast(output_data), - input_ptr, - (CUDA_LONG)N); - break; +#define CASE_ELEMENT_TYPE(type) \ + case sizeof(type): { \ + _ConcatKernel<<>>( \ + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, concat_sizes, concat_sizes_range, \ + axis_dimension_input_output_mapping, reinterpret_cast::MappedType*>(output_data), input_data, \ + N); \ + } break; + CASE_ELEMENT_TYPE(int8_t); + CASE_ELEMENT_TYPE(int16_t); + CASE_ELEMENT_TYPE(int32_t); + CASE_ELEMENT_TYPE(int64_t); +#undef CASE_ELEMENT_TYPE default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator"); } diff --git a/onnxruntime/core/providers/cuda/tensor/concat_impl.h b/onnxruntime/core/providers/cuda/tensor/concat_impl.h index 84a7aa4651..e4b2417cfc 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/concat_impl.h @@ -9,26 +9,15 @@ namespace onnxruntime { namespace cuda { -template -Status ConcatSameConcatDimImpl(cudaStream_t stream, - const size_t element_bytes, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t concat_size, - void* output_data, - const InputIndexToMemoryMap input_ptr, - const size_t N); +template +Status ConcatSameConcatDimImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t concat_size, void* output_data, + const InputDataArray input_data, const size_t output_size); -Status ConcatImpl(cudaStream_t stream, - const size_t element_bytes, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t* concat_sizes, - const int64_t* concat_sizes_range, - const int64_t* axis_dimension_input_output_mapping, - void* output_data, - const void** input_ptr, - const size_t N); +Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t* concat_sizes, const int64_t* concat_sizes_range, + const int64_t* axis_dimension_input_output_mapping, void* output_data, const void** input_data, + const size_t output_size); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index ae6d2830e3..7e26c6ff5a 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cuda/tensor/split.h" + #include "core/providers/cuda/tensor/split_impl.h" #include "core/providers/cpu/tensor/utils.h" @@ -34,7 +35,7 @@ ONNX_OPERATOR_KERNEL_EX(Split, Status Split::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input_tensor = ctx->Input(0); - ORT_ENFORCE(nullptr != input_tensor); + ORT_ENFORCE(input_tensor); auto& input_shape = input_tensor->Shape(); auto num_outputs = ctx->OutputCount(); int64_t axis = HandleNegativeAxis(axis_, input_shape.NumDimensions()); @@ -44,7 +45,7 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const { std::vector split_sizes(num_outputs); const Tensor* split_tensor = ctx->Input(1); - if (split_tensor != nullptr) { + if (split_tensor) { ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "An split tensor must be a vector tensor."); auto nDims = static_cast(split_tensor->Shape()[0]); const int64_t* data = split_tensor->template Data(); @@ -83,35 +84,38 @@ Status Split::ComputeInternal(OpKernelContext* ctx) const { } } - if (input_tensor->Shape().Size() > 0) { - ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu()); + if (input_tensor->Shape().Size() <= 0) return Status::OK(); + size_t element_size = input_tensor->DataType()->Size(); + if (std::all_of(split_sizes.begin(), split_sizes.end(), [&](int64_t size) { return size == split_sizes[0]; })) { + if (num_outputs <= 32) { + TArray output_ptr_array(num_outputs); + for (int i = 0; i < num_outputs; ++i) output_ptr_array[i] = output_ptr_span[i]; + ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), element_size, block_size_including_axis_dim, + block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data, + output_ptr_array, static_cast(input_shape.Size()))); + } else { + ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu()); + ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), element_size, block_size_including_axis_dim, + block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data, + output_ptr.GpuPtr(), static_cast(input_shape.Size()))); + } + } else { + ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu()); CudaAsyncBuffer split_sizes_gpu(this, split_sizes); ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu()); - std::vector split_sizes_range(split_sizes); for (size_t i = 1; i < split_sizes_range.size(); ++i) { split_sizes_range[i] += split_sizes_range[i - 1]; } - CudaAsyncBuffer split_sizes_range_gpu(this, split_sizes_range); ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu()); - CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu()); - - size_t element_size = input_tensor->DataType()->Size(); - ORT_RETURN_IF_ERROR(SplitImpl(Stream(), - element_size, - block_size_including_axis_dim, - block_size_inside_axis_dim, - split_sizes_gpu.GpuPtr(), - split_sizes_range_gpu.GpuPtr(), - axis_dimension_input_output_mapping_gpu.GpuPtr(), - num_outputs, - input_data, - output_ptr.GpuPtr(), - input_shape.Size())); + ORT_RETURN_IF_ERROR(SplitImpl(Stream(), element_size, block_size_including_axis_dim, block_size_inside_axis_dim, + split_sizes_gpu.GpuPtr(), split_sizes_range_gpu.GpuPtr(), + axis_dimension_input_output_mapping_gpu.GpuPtr(), num_outputs, input_data, + output_ptr.GpuPtr(), static_cast(input_shape.Size()))); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index 82dd4f9c47..b0ff856a43 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -1,91 +1,79 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/cuda/tensor/split_impl.h" + #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" -#include "split_impl.h" namespace onnxruntime { namespace cuda { -template +namespace { +#ifdef USE_ROCM +constexpr int kNumElementsPerThread = 2; +constexpr int kNumThreadsPerBlock = 512; +#else +constexpr int kNumElementsPerThread = GridDim::maxElementsPerThread; +constexpr int kNumThreadsPerBlock = GridDim::maxThreadsPerBlock; +#endif +} // namespace + +template __global__ void _SplitKernelSameSplitDim(const fast_divmod block_size_including_axis_dim_div, - const fast_divmod block_size_inside_axis_dim_div, - const fast_divmod split_dim_size, - const int num_outputs, - const T* input_data, - OutputIndexToMemoryMap output_ptr, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - CUDA_LONG output_pos = 0; + const fast_divmod block_size_inside_axis_dim_div, + const fast_divmod split_dim_size, const int num_outputs, const T* input_data, + OutputDataArray output_data, const CUDA_LONG N) { + CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x; + T value[kNumElementsPerThread]; - int outer_block_index = 0; - int block_index = 0; - int offset = 0; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + value[i] = input_data[id]; + id += kNumThreadsPerBlock; + } + } - block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); - block_size_inside_axis_dim_div.divmod(offset, block_index, offset); - - int output_index = 0; - int block_offset = 0; - split_dim_size.divmod(block_index, output_index, block_offset); - - output_pos = (outer_block_index * split_dim_size.d_ + block_offset) * - block_size_inside_axis_dim_div.d_ + - offset; - - reinterpret_cast(output_ptr[output_index])[output_pos] = input_data[id]; + id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + int outer_block_index, block_index, offset, output_index, block_offset; + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); + block_size_inside_axis_dim_div.divmod(offset, block_index, offset); + split_dim_size.divmod(block_index, output_index, block_offset); + CUDA_LONG output_pos = + (outer_block_index * split_dim_size.d_ + block_offset) * block_size_inside_axis_dim_div.d_ + offset; + reinterpret_cast(output_data[output_index])[output_pos] = value[i]; + id += kNumThreadsPerBlock; + } + } } -template -Status SplitSameSplitDimImpl(cudaStream_t stream, - const size_t element_size, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t split_size, - const int num_outputs, - const void* input_data, - OutputIndexToMemoryMap output_ptr, - const size_t N) { - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - +template +Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t split_size, const int num_outputs, + const void* input_data, OutputDataArray output_data, const size_t input_size) { + CUDA_LONG N = static_cast(input_size); + int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock); fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim); fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim); - fast_divmod split_size_div = fast_divmod((int)split_size); + fast_divmod split_size_div = fast_divmod(static_cast(split_size)); switch (element_size) { - case sizeof(int8_t): - _SplitKernelSameSplitDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_size_div, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; - case sizeof(int16_t): - _SplitKernelSameSplitDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_size_div, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; - case sizeof(int32_t): - _SplitKernelSameSplitDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_size_div, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; - case sizeof(int64_t): - _SplitKernelSameSplitDim<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_size_div, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; +#define CASE_ELEMENT_TYPE(type) \ + case sizeof(type): { \ + _SplitKernelSameSplitDim<<>>( \ + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, split_size_div, num_outputs, \ + reinterpret_cast::MappedType*>(input_data), output_data, N); \ + } break + CASE_ELEMENT_TYPE(int8_t); + CASE_ELEMENT_TYPE(int16_t); + CASE_ELEMENT_TYPE(int32_t); + CASE_ELEMENT_TYPE(int64_t); +#undef CASE_ELEMENT_TYPE default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator"); } @@ -93,106 +81,75 @@ Status SplitSameSplitDimImpl(cudaStream_t stream, return Status::OK(); } -template Status SplitSameSplitDimImpl(cudaStream_t stream, - const size_t element_size, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t split_size, - const int num_outputs, - const void* input_data, - void** output_ptr, - const size_t N); +template Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t split_size, + const int num_outputs, const void* input_data, void** output_data, + const size_t input_size); + +template Status SplitSameSplitDimImpl>(cudaStream_t stream, const size_t element_size, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t split_size, + const int num_outputs, const void* input_data, + TArray output_data, const size_t input_size); -template Status SplitSameSplitDimImpl>(cudaStream_t stream, - const size_t element_size, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t split_size, - const int num_outputs, - const void* input_data, - TArray output_ptr, - const size_t N); - template __global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div, - const fast_divmod block_size_inside_axis_dim_div, - const int64_t* split_sizes, - const int64_t* split_sizes_range, - const int64_t* axis_dimension_input_output_mapping, - const int num_outputs, - const T* input_data, - void** output_ptr, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - CUDA_LONG output_pos = 0; + const fast_divmod block_size_inside_axis_dim_div, const int64_t* split_sizes, + const int64_t* split_sizes_range, const int64_t* axis_dimension_input_output_mapping, + const int num_outputs, const T* input_data, void** output_data, const CUDA_LONG N) { + CUDA_LONG start = kNumElementsPerThread * kNumThreadsPerBlock * blockIdx.x + threadIdx.x; + T value[kNumElementsPerThread]; - int outer_block_index = 0; - int block_index = 0; - int offset = 0; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + value[i] = input_data[id]; + id += kNumThreadsPerBlock; + } + } - block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); - block_size_inside_axis_dim_div.divmod(offset, block_index, offset); - - int output_index = axis_dimension_input_output_mapping[block_index]; - int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1]; - int block_offset = block_index - range_left; - - output_pos = (outer_block_index * split_sizes[output_index] + block_offset) * - block_size_inside_axis_dim_div.d_ + - offset; - - reinterpret_cast(output_ptr[output_index])[output_pos] = input_data[id]; + id = start; +#pragma unroll + for (int i = 0; i < kNumElementsPerThread; ++i) { + if (id < N) { + int outer_block_index, block_index, offset; + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); + block_size_inside_axis_dim_div.divmod(offset, block_index, offset); + int output_index = axis_dimension_input_output_mapping[block_index]; + int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1]; + int block_offset = block_index - static_cast(range_left); + CUDA_LONG output_pos = + (outer_block_index * split_sizes[output_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset; + reinterpret_cast(output_data[output_index])[output_pos] = value[i]; + id += kNumThreadsPerBlock; + } + } } -Status SplitImpl(cudaStream_t stream, - const size_t element_size, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t* split_sizes, - const int64_t* split_sizes_range, - const int64_t* axis_dimension_input_output_mapping, - const int num_outputs, - const void* input_data, - void** output_ptr, - const size_t N) { - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - +Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t* split_sizes, const int64_t* split_sizes_range, + const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data, + void** output_data, const size_t input_size) { + CUDA_LONG N = static_cast(input_size); + int blocksPerGrid = CeilDiv(N, kNumElementsPerThread * kNumThreadsPerBlock); fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim); fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim); switch (element_size) { - case sizeof(int8_t): - _SplitKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; - case sizeof(int16_t): - _SplitKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; - case sizeof(int32_t): - _SplitKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; - case sizeof(int64_t): - _SplitKernel<<>>( - block_size_including_axis_dim_div, block_size_inside_axis_dim_div, - split_sizes, split_sizes_range, axis_dimension_input_output_mapping, num_outputs, - reinterpret_cast::MappedType*>(input_data), - output_ptr, - (CUDA_LONG)N); - break; +#define CASE_ELEMENT_TYPE(type) \ + case sizeof(type): { \ + _SplitKernel<<>>( \ + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, split_sizes, split_sizes_range, \ + axis_dimension_input_output_mapping, num_outputs, \ + reinterpret_cast::MappedType*>(input_data), output_data, N); \ + } break + CASE_ELEMENT_TYPE(int8_t); + CASE_ELEMENT_TYPE(int16_t); + CASE_ELEMENT_TYPE(int32_t); + CASE_ELEMENT_TYPE(int64_t); +#undef CASE_ELEMENT_TYPE default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator"); } diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.h b/onnxruntime/core/providers/cuda/tensor/split_impl.h index 41301c9d55..16961cfb7d 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.h @@ -9,29 +9,15 @@ namespace onnxruntime { namespace cuda { -template -Status SplitSameSplitDimImpl(cudaStream_t stream, - const size_t element_size, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t split_size, - const int num_outputs, - const void* input_data, - OutputIndexToMemoryMap output_ptr, - const size_t N); +template +Status SplitSameSplitDimImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t split_size, const int num_outputs, + const void* input_data, OutputDataArray output_data, const size_t input_size); - -Status SplitImpl(cudaStream_t stream, - const size_t element_size, - const int block_size_including_axis_dim, - const int block_size_inside_axis_dim, - const int64_t* split_sizes, - const int64_t* split_sizes_range, - const int64_t* axis_dimension_input_output_mapping, - const int num_outputs, - const void* input_data, - void** output_ptr, - const size_t N); +Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, const int64_t* split_sizes, const int64_t* split_sizes_range, + const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data, + void** output_data, const size_t input_size); } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 526a6d559f..a3192b5404 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -5,6 +5,7 @@ import torch import torch.onnx.symbolic_helper as sym_help +from packaging.version import Version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args @@ -20,7 +21,9 @@ class CustomOpSymbolicRegistry: def register_all(cls): for name, fn in cls._SYMBOLICS.items(): # Symbolic name is in format: domain::name - register_custom_op_symbolic(name, fn, 1) + # Exporter will fail to register symbolic with non-empty domain when torch version is < 1.11.0. + if Version(torch.__version__) >= Version("1.11.0") or name.startswith("::"): + register_custom_op_symbolic(name, fn, 1) def register_symbolic(name, domain=""): @@ -220,6 +223,34 @@ def squeeze(g, self, dim=None): return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim]) +# Exporter's prim::ConstantChunk uses multiple Slice nodes, which is fine for inference. +# For training, the gradient graph will be multiple SliceGrad and one Sum, which is inefficient compared to +# exporting to Split with SplitGrad as gradient graph. +@register_symbolic("ConstantChunk", "prim") +def prim_ConstantChunk(g, self, chunks, dim): + input_shape_dim = g.op( + "Gather", g.op("Shape", self), g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)), axis_i=0 + ) + chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)) + chunk_dim = g.op( + "Div", + g.op("Add", input_shape_dim, chunk_size_minus_1), + g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)), + ) + return g.op( + "Split", + self, + g.op( + "Concat", + g.op("Expand", chunk_dim, chunk_size_minus_1), + g.op("Sub", input_shape_dim, g.op("Mul", chunk_dim, chunk_size_minus_1)), + axis_i=0, + ), + axis_i=dim, + outputs=chunks, + ) + + # For torch.einsum. def parse_equation(equation): pos_comma = equation.find(",") diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 3ac07ccf6d..01a47459e3 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -11,14 +11,15 @@ import random import tempfile import warnings from collections import OrderedDict, namedtuple -from distutils.version import LooseVersion from inspect import signature from time import sleep from unittest.mock import patch import _test_helpers +import onnx import pytest import torch +from packaging.version import Version # Import autocasting libs from torch.cuda import amp @@ -1095,7 +1096,7 @@ def test_export_correctness_pool2d(pool_type, stride): torch.bfloat16, marks=[ pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.10.0"), + Version(torch.__version__) < Version("1.10.0"), reason="PyTorch 1.9 incompatible", ) ], @@ -1141,7 +1142,7 @@ def test_gradient_correctness_minmax(operator, dim, keepdim, data_type): # Before 1.10 (excluded), Torch's min/max(x,y) will assign dY to y's dX if value from x and y are equal. # From 1.10, both x and y's dX will be dY/2. ORT follows this distribution logic, so skip below test if Torch version # is lower than 1.10. -@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.10.0"), reason="PyTorch 1.9 incompatible") +@pytest.mark.skipif(Version(torch.__version__) < Version("1.10.0"), reason="PyTorch 1.9 incompatible") @pytest.mark.parametrize("operator", ["min", "max"]) def test_gradient_correctness_minmax_two_tensors(operator): func = getattr(torch, operator) @@ -1322,12 +1323,62 @@ def test_gradient_correctness_reducesum(dim, keepdim): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +# Before PyTorch 1.11.0, the exporter will fail to register symbolic with non-empty domain. +@pytest.mark.skipif(Version(torch.__version__) < Version("1.11.0"), reason="PyTorch 1.10 incompatible") +@pytest.mark.parametrize("dim", [0, 1, -1]) +def test_gradient_correctness_chunk(dim): + class NeuralNetChunk(torch.nn.Module): + def __init__(self, dim): + super(NeuralNetChunk, self).__init__() + self.dim = dim + + def forward(self, input): + return input.chunk(3, dim=self.dim) + + device = "cuda" + pt_model = NeuralNetChunk(dim).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="chunk_model")) + + def run_step(model, input): + y1, y2, y3 = model(input) + loss = y1.sum() + y2.sum() + y3.sum() + loss.backward() + return y1, y2, y3 + + N, D, H = 16, 17, 18 + for _ in range(10): + input = torch.rand((N, D, H), device=device, requires_grad=True) + pt_y1, pt_y2, pt_y3 = run_step(pt_model, input) + ort_y1, ort_y2, ort_y3 = run_step(ort_model, input) + + _test_helpers.assert_values_are_close(ort_y1, pt_y1) + _test_helpers.assert_values_are_close(ort_y2, pt_y2) + _test_helpers.assert_values_are_close(ort_y3, pt_y3) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx")) + model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) + has_split = False + for node in model.graph.node: + if node.op_type == "Split": + has_split = True + break + assert has_split + os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) + os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx")) + os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx")) + os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx")) + + # In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that # contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm". # Currently skip these cases and test_gradient_correctness_einsum_2, # will enable these tests again once the issue in PyTorch is fixed. skip_torch_1_11 = pytest.mark.skipif( - LooseVersion(torch.__version__) >= LooseVersion("1.11.0"), reason="PyTorch 1.11 incompatible" + Version(torch.__version__) >= Version("1.11.0"), reason="PyTorch 1.11 incompatible" ) @@ -4688,7 +4739,7 @@ def test_ortmodule_ortmodule_method_attribute_copy(): assert type(out2.grad_fn).__name__ == "_ORTModuleFunctionBackward" assert ( type(out3.grad_fn).__name__ == "AddmmBackward0" - if LooseVersion(torch.__version__) >= LooseVersion("1.10.0") + if Version(torch.__version__) >= Version("1.10.0") else "AddmmBackward" )