diff --git a/onnxruntime/core/providers/cuda/tensor/concat_impl.cu b/onnxruntime/core/providers/cuda/tensor/concat_impl.cu index 6047f12189..51c8444196 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/concat_impl.cu @@ -8,6 +8,109 @@ namespace onnxruntime { namespace cuda { +// concat dimension sizes are the same for all inputs +template +__global__ void _ConcatKernelSameConcatDim(const fast_divmod block_size_including_axis_dim_div, + const fast_divmod block_size_inside_axis_dim_div, + const fast_divmod concat_dim_size, + T* output_data, + InputIndexToMemoryMap input_ptr, + const CUDA_LONG N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG input_pos = 0; + + int outer_block_index = 0; + int block_index = 0; + int offset = 0; + + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); + block_size_inside_axis_dim_div.divmod(offset, block_index, offset); + + int input_index = 0; + int block_offset = 0; + concat_dim_size.divmod(block_index, input_index, block_offset); + + input_pos = (outer_block_index * concat_dim_size.d_ + block_offset) * + block_size_inside_axis_dim_div.d_ + + offset; + + output_data[id] = reinterpret_cast(input_ptr[input_index])[input_pos]; +} + +template +Status ConcatSameConcatDimImpl(cudaStream_t stream, + const size_t element_bytes, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t concat_size, + void* output_data, + const InputIndexToMemoryMap input_ptr, + const size_t N) { + int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + + fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim); + fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim); + fast_divmod concat_dim_size = fast_divmod(static_cast(concat_size)); + switch (element_bytes) { + case sizeof(int8_t): + _ConcatKernelSameConcatDim<<>>( + 
block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + concat_dim_size, + reinterpret_cast(output_data), + input_ptr, + (CUDA_LONG)N); + break; + case sizeof(int16_t): + _ConcatKernelSameConcatDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + concat_dim_size, + reinterpret_cast(output_data), + input_ptr, + (CUDA_LONG)N); + break; + case sizeof(int32_t): + _ConcatKernelSameConcatDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + concat_dim_size, + reinterpret_cast(output_data), + input_ptr, + (CUDA_LONG)N); + break; + case sizeof(int64_t): + _ConcatKernelSameConcatDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + concat_dim_size, + reinterpret_cast(output_data), + input_ptr, + (CUDA_LONG)N); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator"); + } + + return Status::OK(); +} + +// input tensors addresses in device memory +template Status ConcatSameConcatDimImpl(cudaStream_t stream, + const size_t element_bytes, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t concat_size, + void* output_data, + const void** input_ptr, + const size_t N); + +// input tensor addresses passed by value +template Status ConcatSameConcatDimImpl>(cudaStream_t stream, + const size_t element_bytes, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t concat_size, + void* output_data, + TArray input_ptr, + const size_t N); + template __global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div, const fast_divmod block_size_inside_axis_dim_div, @@ -20,18 +123,18 @@ __global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_di CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_pos = 0; - int outter_block_index = 0; + int outer_block_index = 0; int block_index = 0; int offset = 0; - 
block_size_including_axis_dim_div.divmod(id, outter_block_index, offset); + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); block_size_inside_axis_dim_div.divmod(offset, block_index, offset); int input_index = axis_dimension_input_output_mapping[block_index]; int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1]; int block_offset = block_index - range_left; - input_pos = (outter_block_index * concat_sizes[input_index] + block_offset) * + input_pos = (outer_block_index * concat_sizes[input_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset; diff --git a/onnxruntime/core/providers/cuda/tensor/concat_impl.h b/onnxruntime/core/providers/cuda/tensor/concat_impl.h index 2a3b6ba9f9..84a7aa4651 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/concat_impl.h @@ -9,6 +9,16 @@ namespace onnxruntime { namespace cuda { +template +Status ConcatSameConcatDimImpl(cudaStream_t stream, + const size_t element_bytes, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t concat_size, + void* output_data, + const InputIndexToMemoryMap input_ptr, + const size_t N); + Status ConcatImpl(cudaStream_t stream, const size_t element_bytes, const int block_size_including_axis_dim, diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index f1565428d6..82dd4f9c47 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -8,6 +8,111 @@ namespace onnxruntime { namespace cuda { +template +__global__ void _SplitKernelSameSplitDim(const fast_divmod block_size_including_axis_dim_div, + const fast_divmod block_size_inside_axis_dim_div, + const fast_divmod split_dim_size, + const int num_outputs, + const T* input_data, + OutputIndexToMemoryMap output_ptr, + const CUDA_LONG N) { + 
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG output_pos = 0; + + int outer_block_index = 0; + int block_index = 0; + int offset = 0; + + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); + block_size_inside_axis_dim_div.divmod(offset, block_index, offset); + + int output_index = 0; + int block_offset = 0; + split_dim_size.divmod(block_index, output_index, block_offset); + + output_pos = (outer_block_index * split_dim_size.d_ + block_offset) * + block_size_inside_axis_dim_div.d_ + + offset; + + reinterpret_cast(output_ptr[output_index])[output_pos] = input_data[id]; +} + +template +Status SplitSameSplitDimImpl(cudaStream_t stream, + const size_t element_size, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t split_size, + const int num_outputs, + const void* input_data, + OutputIndexToMemoryMap output_ptr, + const size_t N) { + int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + + fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim); + fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim); + fast_divmod split_size_div = fast_divmod((int)split_size); + + switch (element_size) { + case sizeof(int8_t): + _SplitKernelSameSplitDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + split_size_div, num_outputs, + reinterpret_cast::MappedType*>(input_data), + output_ptr, + (CUDA_LONG)N); + break; + case sizeof(int16_t): + _SplitKernelSameSplitDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + split_size_div, num_outputs, + reinterpret_cast::MappedType*>(input_data), + output_ptr, + (CUDA_LONG)N); + break; + case sizeof(int32_t): + _SplitKernelSameSplitDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + split_size_div, num_outputs, + reinterpret_cast::MappedType*>(input_data), + output_ptr, + (CUDA_LONG)N); + break; 
+ case sizeof(int64_t): + _SplitKernelSameSplitDim<<>>( + block_size_including_axis_dim_div, block_size_inside_axis_dim_div, + split_size_div, num_outputs, + reinterpret_cast::MappedType*>(input_data), + output_ptr, + (CUDA_LONG)N); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator"); + } + + return Status::OK(); +} + +template Status SplitSameSplitDimImpl(cudaStream_t stream, + const size_t element_size, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t split_size, + const int num_outputs, + const void* input_data, + void** output_ptr, + const size_t N); + +template Status SplitSameSplitDimImpl>(cudaStream_t stream, + const size_t element_size, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t split_size, + const int num_outputs, + const void* input_data, + TArray output_ptr, + const size_t N); + template __global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div, const fast_divmod block_size_inside_axis_dim_div, @@ -21,18 +126,18 @@ __global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG output_pos = 0; - int outter_block_index = 0; + int outer_block_index = 0; int block_index = 0; int offset = 0; - block_size_including_axis_dim_div.divmod(id, outter_block_index, offset); + block_size_including_axis_dim_div.divmod(id, outer_block_index, offset); block_size_inside_axis_dim_div.divmod(offset, block_index, offset); int output_index = axis_dimension_input_output_mapping[block_index]; int64_t range_left = (output_index == 0) ? 
0 : split_sizes_range[output_index - 1]; int block_offset = block_index - range_left; - output_pos = (outter_block_index * split_sizes[output_index] + block_offset) * + output_pos = (outer_block_index * split_sizes[output_index] + block_offset) * block_size_inside_axis_dim_div.d_ + offset; @@ -96,4 +201,4 @@ Status SplitImpl(cudaStream_t stream, } } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.h b/onnxruntime/core/providers/cuda/tensor/split_impl.h index a8fde02549..41301c9d55 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.h @@ -9,6 +9,18 @@ namespace onnxruntime { namespace cuda { +template +Status SplitSameSplitDimImpl(cudaStream_t stream, + const size_t element_size, + const int block_size_including_axis_dim, + const int block_size_inside_axis_dim, + const int64_t split_size, + const int num_outputs, + const void* input_data, + OutputIndexToMemoryMap output_ptr, + const size_t N); + + Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block_size_including_axis_dim, @@ -22,4 +34,4 @@ Status SplitImpl(cudaStream_t stream, const size_t N); } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/concat_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/concat_op_test.cc index 3db482ded4..a242384b40 100644 --- a/orttraining/orttraining/test/training_ops/cpu/tensor/concat_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/tensor/concat_op_test.cc @@ -67,6 +67,52 @@ TEST(ConcatTrainingOpTest, Concat3D_same_len) { test.Run(); } +void Setup_Concat3D_same_len_N_inputs(OpTester& test, const int num_inputs) { + + test.AddAttribute("axis", int64_t{1}); + + std::vector dims{2, 2, 2}; + + std::vector idata(8); + std::vector 
odata(8*num_inputs); + auto odata_at = [&odata, num_inputs] (int i, int j, int k) -> std::vector::iterator { + return std::next(odata.begin(), 4*num_inputs*i + 2*j + k); + }; + + float counter = 1.0f; + std::stringstream ss; + for (int i = 0; i < num_inputs; i++) { + + std::iota(idata.begin(), idata.end(), counter); + counter += (float)idata.size(); + + ss.str(""); + ss << "input" << i; + test.AddInput(ss.str().c_str(), dims, idata); + + std::copy(idata.begin(), idata.begin() + 4, odata_at(0,2*i,0)); + std::copy(idata.begin() + 4, idata.end(), odata_at(1,2*i,0)); + } + + std::vector per_input_length(num_inputs, 2); + test.AddOutput("concat_result", {2, 2*num_inputs, 2}, odata); + test.AddOutput("per_input_length", {num_inputs}, per_input_length); +} + +// with <= 32 inputs, the input tensor addresses are passed as kernel args +TEST(ConcatTrainingOpTest, Concat3D_same_len_16_inputs) { + OpTester test("ConcatTraining", 1, kMSDomain); + Setup_Concat3D_same_len_N_inputs(test, 16); + test.Run(); +} + +// with > 32 inputs, the input tensor addresses are passed in a device buffer +TEST(ConcatTrainingOpTest, Concat3D_same_len_64_inputs) { + OpTester test("ConcatTraining", 1, kMSDomain); + Setup_Concat3D_same_len_N_inputs(test, 64); + test.Run(); +} + TEST(ConcatTrainingOpTest, Concat2D_optional_output1) { OpTester test("ConcatTraining", 1, kMSDomain); test.AddAttribute("axis", int64_t{1}); diff --git a/orttraining/orttraining/test/training_ops/cpu/tensor/split_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/tensor/split_op_test.cc index b727df0950..409a0faac2 100644 --- a/orttraining/orttraining/test/training_ops/cpu/tensor/split_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/tensor/split_op_test.cc @@ -59,6 +59,40 @@ TEST(SplitTrainingOpTest, Axis0EqualSplitFloat) { SplitTrainingOpTester(axis, {}, input, outputs); } +std::tuple> +Setup_Axis0EqualSplitFloat_N_inputs(const int num_outputs) { + + float counter = 1.0f; + std::vector data(4*num_outputs); + 
std::iota(data.begin(), data.end(), counter); + ShapeAndFloatData input = {{2*num_outputs, 2}, data}; + + data.resize(4); + std::vector outputs; + for (int i = 0; i < num_outputs; i++) { + std::iota(data.begin(), data.end(), counter); + outputs.push_back({{2, 2}, data}); + counter += (float)data.size(); + } + + // due to const on ShapeAndFloatData + return std::make_tuple(input, outputs); +} + +// <=32 with same sizes passes output addresses as kernel args +TEST(SplitTrainingOpTest, Axis0EqualSplitFloat_16_outputs) { + const int64_t axis = 0; + auto io = Setup_Axis0EqualSplitFloat_N_inputs(16); + SplitTrainingOpTester(axis, {}, std::get<0>(io), std::get<1>(io)); +} + +// > 32 with same sizes passes output addresses as device buffer +TEST(SplitTrainingOpTest, Axis0EqualSplitFloat_64_outputs) { + const int64_t axis = 0; + auto io = Setup_Axis0EqualSplitFloat_N_inputs(64); + SplitTrainingOpTester(axis, {}, std::get<0>(io), std::get<1>(io)); +} + TEST(SplitTrainingOpTest, Axis0UnequalSplitFloat) { const int64_t axis = 0; std::vector outputs; diff --git a/orttraining/orttraining/training_ops/cuda/tensor/concat.cc b/orttraining/orttraining/training_ops/cuda/tensor/concat.cc index 103cc91892..de5f943640 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/concat.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/concat.cc @@ -50,26 +50,59 @@ Status ConcatTraining::ComputeInternal(OpKernelContext* ctx) const { concat_sizes_range[i] += concat_sizes_range[i - 1]; } - CudaAsyncBuffer concat_sizes_gpu(this, concat_sizes); - CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); - CudaAsyncBuffer concat_sizes_range_gpu(this, concat_sizes_range); - concat_sizes_gpu.CopyToGpu(); - axis_dimension_input_output_mapping_gpu.CopyToGpu(); - concat_sizes_range_gpu.CopyToGpu(); - input_ptr.CopyToGpu(); + auto element_bytes = p.output_tensor->DataType()->Size(); int block_size_inside_axis_dim = static_cast(p.output_axis_pitch / 
p.output_tensor->Shape()[p.axis]); int block_size_including_axis_dim = static_cast(p.output_axis_pitch); - auto element_bytes = p.output_tensor->DataType()->Size(); - ORT_RETURN_IF_ERROR(ConcatImpl(Stream(), + if (std::all_of(concat_sizes.begin(), concat_sizes.end(), [&] (int64_t i) {return i == concat_sizes[0];})) { + if (input_count <= 32) { + // pass by value to avoid host-to-device copy on same stream + TArray input_table(input_count); + for (int i = 0; i < input_count; ++i) { + input_table[i] = input_ptr_cpuspan[i]; + } + ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(Stream(), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, - concat_sizes_gpu.GpuPtr(), - concat_sizes_range_gpu.GpuPtr(), - axis_dimension_input_output_mapping_gpu.GpuPtr(), + concat_sizes[0], + p.output_tensor->MutableDataRaw(), + input_table, + p.output_num_elements)); + + } else { + // too many inputs to pass by value, so copy the input addresses to device memory + input_ptr.CopyToGpu(); + ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(Stream(), + element_bytes, + block_size_including_axis_dim, + block_size_inside_axis_dim, + concat_sizes[0], p.output_tensor->MutableDataRaw(), input_ptr.GpuPtr(), p.output_num_elements)); + } + } else { + // input sizes vary, copy input sizes and range metadata to device + // todo: pass by value when few inputs + input_ptr.CopyToGpu(); + CudaAsyncBuffer concat_sizes_gpu(this, concat_sizes); + CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); + CudaAsyncBuffer concat_sizes_range_gpu(this, concat_sizes_range); + concat_sizes_gpu.CopyToGpu(); + axis_dimension_input_output_mapping_gpu.CopyToGpu(); + concat_sizes_range_gpu.CopyToGpu(); + + ORT_RETURN_IF_ERROR(ConcatImpl(Stream(), + element_bytes, + block_size_including_axis_dim, + block_size_inside_axis_dim, + concat_sizes_gpu.GpuPtr(), + concat_sizes_range_gpu.GpuPtr(), + axis_dimension_input_output_mapping_gpu.GpuPtr(), + p.output_tensor->MutableDataRaw(), + 
input_ptr.GpuPtr(), + p.output_num_elements)); + } // Create optional output tensor for 'per_input_length' Tensor* per_input_length_tensor = ctx->Output(1, {input_count}); diff --git a/orttraining/orttraining/training_ops/cuda/tensor/split.cc b/orttraining/orttraining/training_ops/cuda/tensor/split.cc index 956c423255..6ff1f510ae 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/split.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/split.cc @@ -65,24 +65,51 @@ Status SplitTraining::ComputeInternal(OpKernelContext* ctx) const { } if (input_tensor->Shape().Size() > 0) { - output_ptr.CopyToGpu(); - - CudaAsyncBuffer split_sizes_gpu(this, split_sizes); - split_sizes_gpu.CopyToGpu(); std::vector split_sizes_range(split_sizes); for (size_t i = 1; i < split_sizes_range.size(); ++i) { split_sizes_range[i] += split_sizes_range[i - 1]; } - CudaAsyncBuffer split_sizes_range_gpu(this, split_sizes_range); - split_sizes_range_gpu.CopyToGpu(); - - CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); - axis_dimension_input_output_mapping_gpu.CopyToGpu(); - size_t element_size = input_tensor->DataType()->Size(); - ORT_RETURN_IF_ERROR(SplitImpl(Stream(), + + if (std::all_of(split_sizes.begin(), split_sizes.end(), [&] (int64_t i) {return i == split_sizes[0];})) { + if (num_outputs <= 32) { + TArray output_table(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + output_table[i] = output_ptr_span[i]; + } + ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), + element_size, + block_size_including_axis_dim, + block_size_inside_axis_dim, + split_sizes[0], + num_outputs, + input_data, + output_table, + input_shape.Size())); + } else { + output_ptr.CopyToGpu(); + ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(), + element_size, + block_size_including_axis_dim, + block_size_inside_axis_dim, + split_sizes[0], + num_outputs, + input_data, + output_ptr.GpuPtr(), + input_shape.Size())); + } + } else { + 
output_ptr.CopyToGpu(); + CudaAsyncBuffer split_sizes_gpu(this, split_sizes); + CudaAsyncBuffer split_sizes_range_gpu(this, split_sizes_range); + CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); + split_sizes_gpu.CopyToGpu(); + split_sizes_range_gpu.CopyToGpu(); + axis_dimension_input_output_mapping_gpu.CopyToGpu(); + + ORT_RETURN_IF_ERROR(SplitImpl(Stream(), element_size, block_size_including_axis_dim, block_size_inside_axis_dim, @@ -93,6 +120,7 @@ Status SplitTraining::ComputeInternal(OpKernelContext* ctx) const { input_data, output_ptr.GpuPtr(), input_shape.Size())); + } } return Status::OK();