Optimize Concat and Split on CUDA to eliminate host-to-device copies when sizes are all the same (#8833)
* special case concat and split when sizes are equal
* add tests for 16 and 32 inputs with same dim
* add tests for 16/64 inputs on concat or 16/64 outputs on split
* try eliminate windows warning
* outter => outer
This commit is contained in:
parent 858989293d
commit 225439193e
8 changed files with 402 additions and 31 deletions
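The core technique in this commit: when every input (for Concat) or output (for Split) has the same length on the concat/split axis and there are at most 32 of them, the tensor addresses are passed to the kernel by value in a TArray<void*, 32>, so no pointer table has to be staged in device memory first. The following is a minimal standalone sketch of that idea, not ORT code; PtrTable32 and gather_first are hypothetical names. Kernel arguments travel in a size-limited parameter buffer (4 KB on most CUDA toolchains), which is why a small fixed capacity such as 32 pointers is a natural cutoff.

#include <cstdio>
#include <cuda_runtime.h>

// A fixed-size pointer table passed by value rides along in the kernel's
// parameter buffer, so no cudaMemcpy of the pointer array is needed.
struct PtrTable32 {
  const float* ptrs[32];
};

__global__ void gather_first(PtrTable32 table, int num_inputs, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_inputs) out[i] = table.ptrs[i][0];  // read element 0 of input i
}

int main() {
  const int n = 4;
  PtrTable32 table{};
  float* out;
  cudaMalloc(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    float* d;
    cudaMalloc(&d, sizeof(float));
    float v = 10.0f * i;
    cudaMemcpy(d, &v, sizeof(float), cudaMemcpyHostToDevice);
    table.ptrs[i] = d;  // filled on the host, copied by value at launch
  }
  gather_first<<<1, 32>>>(table, n, out);
  float host[4];
  cudaMemcpy(host, out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("out[%d] = %g\n", i, host[i]);
  return 0;
}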
@@ -8,6 +8,109 @@
 namespace onnxruntime {
 namespace cuda {
 
+// concat dimension are same for all inputs
+template <typename T, typename InputIndexToMemoryMap>
+__global__ void _ConcatKernelSameConcatDim(const fast_divmod block_size_including_axis_dim_div,
+                                           const fast_divmod block_size_inside_axis_dim_div,
+                                           const fast_divmod concat_dim_size,
+                                           T* output_data,
+                                           InputIndexToMemoryMap input_ptr,
+                                           const CUDA_LONG N) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+  CUDA_LONG input_pos = 0;
+
+  int outer_block_index = 0;
+  int block_index = 0;
+  int offset = 0;
+
+  block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
+  block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
+
+  int input_index = 0;
+  int block_offset = 0;
+  concat_dim_size.divmod(block_index, input_index, block_offset);
+
+  input_pos = (outer_block_index * concat_dim_size.d_ + block_offset) *
+                  block_size_inside_axis_dim_div.d_ +
+              offset;
+
+  output_data[id] = reinterpret_cast<const T*>(input_ptr[input_index])[input_pos];
+}
+
+template <typename InputIndexToMemoryMap>
+Status ConcatSameConcatDimImpl(cudaStream_t stream,
+                               const size_t element_bytes,
+                               const int block_size_including_axis_dim,
+                               const int block_size_inside_axis_dim,
+                               const int64_t concat_size,
+                               void* output_data,
+                               const InputIndexToMemoryMap input_ptr,
+                               const size_t N) {
+  int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
+
+  fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
+  fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
+  fast_divmod concat_dim_size = fast_divmod(static_cast<int>(concat_size));
+  switch (element_bytes) {
+    case sizeof(int8_t):
+      _ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          concat_dim_size,
+          reinterpret_cast<int8_t*>(output_data),
+          input_ptr,
+          (CUDA_LONG)N);
+      break;
+    case sizeof(int16_t):
+      _ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          concat_dim_size,
+          reinterpret_cast<int16_t*>(output_data),
+          input_ptr,
+          (CUDA_LONG)N);
+      break;
+    case sizeof(int32_t):
+      _ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          concat_dim_size,
+          reinterpret_cast<int32_t*>(output_data),
+          input_ptr,
+          (CUDA_LONG)N);
+      break;
+    case sizeof(int64_t):
+      _ConcatKernelSameConcatDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          concat_dim_size,
+          reinterpret_cast<int64_t*>(output_data),
+          input_ptr,
+          (CUDA_LONG)N);
+      break;
+    default:
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Concat operator");
+  }
+
+  return Status::OK();
+}
+
+// input tensors addresses in device memory
+template Status ConcatSameConcatDimImpl<const void**>(cudaStream_t stream,
+                                                      const size_t element_bytes,
+                                                      const int block_size_including_axis_dim,
+                                                      const int block_size_inside_axis_dim,
+                                                      const int64_t concat_size,
+                                                      void* output_data,
+                                                      const void** input_ptr,
+                                                      const size_t N);
+
+// input tensor addresses passed by value
+template Status ConcatSameConcatDimImpl<TArray<const void*,32>>(cudaStream_t stream,
+                                                                const size_t element_bytes,
+                                                                const int block_size_including_axis_dim,
+                                                                const int block_size_inside_axis_dim,
+                                                                const int64_t concat_size,
+                                                                void* output_data,
+                                                                TArray<const void*,32> input_ptr,
+                                                                const size_t N);
+
 template <typename T>
 __global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div,
                               const fast_divmod block_size_inside_axis_dim_div,
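For reference, the index arithmetic in _ConcatKernelSameConcatDim decomposes each flat output index into (outer block, axis position, inner offset), then maps the axis position onto (input index, axis offset inside that input). Below is a host-side replay of that math for one concrete case, two [2,2,2] inputs concatenated on axis 1, with plain / and % standing in for fast_divmod. This is an illustration only, not ORT code.

#include <cstdio>

// Host-side replay of the kernel's index math:
// two inputs of shape [2,2,2] concatenated on axis 1 -> output [2,4,2].
int main() {
  const int concat_size = 2;                       // per-input length on the concat axis
  const int block_size_inside_axis_dim = 2;        // product of dims after the axis
  const int block_size_including_axis_dim = 4 * 2; // output axis dim * inner block
  const int N = 2 * 4 * 2;                         // total output elements

  for (int id = 0; id < N; ++id) {
    int outer_block_index = id / block_size_including_axis_dim;
    int offset = id % block_size_including_axis_dim;
    int block_index = offset / block_size_inside_axis_dim;  // output axis coordinate
    offset %= block_size_inside_axis_dim;
    int input_index = block_index / concat_size;            // which input tensor
    int block_offset = block_index % concat_size;           // axis coordinate inside it
    int input_pos = (outer_block_index * concat_size + block_offset) *
                        block_size_inside_axis_dim +
                    offset;
    printf("output[%2d] <- input%d[%d]\n", id, input_index, input_pos);
  }
  return 0;
}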
@@ -20,18 +123,18 @@ __global__ void _ConcatKernel(const fast_divmod block_size_including_axis_dim_div,
   CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
   CUDA_LONG input_pos = 0;
 
-  int outter_block_index = 0;
+  int outer_block_index = 0;
   int block_index = 0;
   int offset = 0;
 
-  block_size_including_axis_dim_div.divmod(id, outter_block_index, offset);
+  block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
   block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
 
   int input_index = axis_dimension_input_output_mapping[block_index];
   int64_t range_left = (input_index == 0) ? 0 : concat_sizes_range[input_index - 1];
   int block_offset = block_index - range_left;
 
-  input_pos = (outter_block_index * concat_sizes[input_index] + block_offset) *
+  input_pos = (outer_block_index * concat_sizes[input_index] + block_offset) *
                   block_size_inside_axis_dim_div.d_ +
               offset;
 
@@ -9,6 +9,16 @@
 namespace onnxruntime {
 namespace cuda {
 
+template <typename InputIndexToMemoryMap>
+Status ConcatSameConcatDimImpl(cudaStream_t stream,
+                               const size_t element_bytes,
+                               const int block_size_including_axis_dim,
+                               const int block_size_inside_axis_dim,
+                               const int64_t concat_size,
+                               void* output_data,
+                               const InputIndexToMemoryMap input_ptr,
+                               const size_t N);
+
 Status ConcatImpl(cudaStream_t stream,
                   const size_t element_bytes,
                   const int block_size_including_axis_dim,
@@ -8,6 +8,111 @@
 namespace onnxruntime {
 namespace cuda {
 
+template <typename T, typename OutputIndexToMemoryMap>
+__global__ void _SplitKernelSameSplitDim(const fast_divmod block_size_including_axis_dim_div,
+                                         const fast_divmod block_size_inside_axis_dim_div,
+                                         const fast_divmod split_dim_size,
+                                         const int num_outputs,
+                                         const T* input_data,
+                                         OutputIndexToMemoryMap output_ptr,
+                                         const CUDA_LONG N) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+  CUDA_LONG output_pos = 0;
+
+  int outer_block_index = 0;
+  int block_index = 0;
+  int offset = 0;
+
+  block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
+  block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
+
+  int output_index = 0;
+  int block_offset = 0;
+  split_dim_size.divmod(block_index, output_index, block_offset);
+
+  output_pos = (outer_block_index * split_dim_size.d_ + block_offset) *
+                   block_size_inside_axis_dim_div.d_ +
+               offset;
+
+  reinterpret_cast<T*>(output_ptr[output_index])[output_pos] = input_data[id];
+}
+
+template <typename OutputIndexToMemoryMap>
+Status SplitSameSplitDimImpl(cudaStream_t stream,
+                             const size_t element_size,
+                             const int block_size_including_axis_dim,
+                             const int block_size_inside_axis_dim,
+                             const int64_t split_size,
+                             const int num_outputs,
+                             const void* input_data,
+                             OutputIndexToMemoryMap output_ptr,
+                             const size_t N) {
+  int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
+
+  fast_divmod block_size_including_axis_dim_div = fast_divmod(block_size_including_axis_dim);
+  fast_divmod block_size_inside_axis_dim_div = fast_divmod(block_size_inside_axis_dim);
+  fast_divmod split_size_div = fast_divmod((int)split_size);
+
+  switch (element_size) {
+    case sizeof(int8_t):
+      _SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          split_size_div, num_outputs,
+          reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
+          output_ptr,
+          (CUDA_LONG)N);
+      break;
+    case sizeof(int16_t):
+      _SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          split_size_div, num_outputs,
+          reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
+          output_ptr,
+          (CUDA_LONG)N);
+      break;
+    case sizeof(int32_t):
+      _SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          split_size_div, num_outputs,
+          reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
+          output_ptr,
+          (CUDA_LONG)N);
+      break;
+    case sizeof(int64_t):
+      _SplitKernelSameSplitDim<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+          block_size_including_axis_dim_div, block_size_inside_axis_dim_div,
+          split_size_div, num_outputs,
+          reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
+          output_ptr,
+          (CUDA_LONG)N);
+      break;
+    default:
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Slice operator");
+  }
+
+  return Status::OK();
+}
+
+template Status SplitSameSplitDimImpl<void**>(cudaStream_t stream,
+                                              const size_t element_size,
+                                              const int block_size_including_axis_dim,
+                                              const int block_size_inside_axis_dim,
+                                              const int64_t split_size,
+                                              const int num_outputs,
+                                              const void* input_data,
+                                              void** output_ptr,
+                                              const size_t N);
+
+template Status SplitSameSplitDimImpl<TArray<void*,32>>(cudaStream_t stream,
+                                                        const size_t element_size,
+                                                        const int block_size_including_axis_dim,
+                                                        const int block_size_inside_axis_dim,
+                                                        const int64_t split_size,
+                                                        const int num_outputs,
+                                                        const void* input_data,
+                                                        TArray<void*,32> output_ptr,
+                                                        const size_t N);
+
 template <typename T>
 __global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div,
                              const fast_divmod block_size_inside_axis_dim_div,
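A note on the switch over element_bytes/element_size in both Impl functions above: the kernels only relocate bytes, so one instantiation per element width (1, 2, 4, or 8 bytes) covers every tensor type of that width; float shares the int32_t path, half the int16_t path, double the int64_t path. A minimal host-side sketch of the same dispatch idea follows; the helper names are hypothetical and this is not ORT code.

#include <cstdint>
#include <cstring>

// Data movement only depends on the element's byte width, so one code
// path per width serves all types of that width.
template <typename T>
void copy_elements(void* dst, const void* src, size_t n) {
  std::memcpy(dst, src, n * sizeof(T));
}

bool dispatch_by_element_size(size_t element_bytes, void* dst, const void* src, size_t n) {
  switch (element_bytes) {
    case sizeof(int8_t):  copy_elements<int8_t>(dst, src, n);  return true;  // also bool
    case sizeof(int16_t): copy_elements<int16_t>(dst, src, n); return true;  // also half
    case sizeof(int32_t): copy_elements<int32_t>(dst, src, n); return true;  // also float
    case sizeof(int64_t): copy_elements<int64_t>(dst, src, n); return true;  // also double
    default: return false;  // variable-width types (e.g. strings) unsupported
  }
}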
@@ -21,18 +126,18 @@ __global__ void _SplitKernel(const fast_divmod block_size_including_axis_dim_div,
   CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
   CUDA_LONG output_pos = 0;
 
-  int outter_block_index = 0;
+  int outer_block_index = 0;
   int block_index = 0;
   int offset = 0;
 
-  block_size_including_axis_dim_div.divmod(id, outter_block_index, offset);
+  block_size_including_axis_dim_div.divmod(id, outer_block_index, offset);
   block_size_inside_axis_dim_div.divmod(offset, block_index, offset);
 
   int output_index = axis_dimension_input_output_mapping[block_index];
   int64_t range_left = (output_index == 0) ? 0 : split_sizes_range[output_index - 1];
   int block_offset = block_index - range_left;
 
-  output_pos = (outter_block_index * split_sizes[output_index] + block_offset) *
+  output_pos = (outer_block_index * split_sizes[output_index] + block_offset) *
                    block_size_inside_axis_dim_div.d_ +
                offset;
 
@@ -96,4 +201,4 @@ Status SplitImpl(cudaStream_t stream,
 }
 
 }  // namespace cuda
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
@@ -9,6 +9,18 @@
 namespace onnxruntime {
 namespace cuda {
 
+template <typename OutputIndexToMemoryMap>
+Status SplitSameSplitDimImpl(cudaStream_t stream,
+                             const size_t element_size,
+                             const int block_size_including_axis_dim,
+                             const int block_size_inside_axis_dim,
+                             const int64_t split_size,
+                             const int num_outputs,
+                             const void* input_data,
+                             OutputIndexToMemoryMap output_ptr,
+                             const size_t N);
+
+
 Status SplitImpl(cudaStream_t stream,
                  const size_t element_size,
                  const int block_size_including_axis_dim,
@@ -22,4 +34,4 @@ Status SplitImpl(cudaStream_t stream,
                  const size_t N);
 
 }  // namespace cuda
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
@@ -67,6 +67,52 @@ TEST(ConcatTrainingOpTest, Concat3D_same_len) {
   test.Run();
 }
 
+void Setup_Concat3D_same_len_N_inputs(OpTester& test, const int num_inputs) {
+  test.AddAttribute("axis", int64_t{1});
+
+  std::vector<int64_t> dims{2, 2, 2};
+
+  std::vector<float> idata(8);
+  std::vector<float> odata(8*num_inputs);
+  auto odata_at = [&odata, num_inputs] (int i, int j, int k) -> std::vector<float>::iterator {
+    return std::next(odata.begin(), 4*num_inputs*i + 2*j + k);
+  };
+
+  float counter = 1.0f;
+  std::stringstream ss;
+  for (int i = 0; i < num_inputs; i++) {
+    std::iota(idata.begin(), idata.end(), counter);
+    counter += (float)idata.size();
+
+    ss.str("");
+    ss << "input" << i;
+    test.AddInput<float>(ss.str().c_str(), dims, idata);
+
+    std::copy(idata.begin(), idata.begin() + 4, odata_at(0, 2*i, 0));
+    std::copy(idata.begin() + 4, idata.end(), odata_at(1, 2*i, 0));
+  }
+
+  std::vector<int64_t> per_input_length(num_inputs, 2);
+  test.AddOutput<float>("concat_result", {2, 2*num_inputs, 2}, odata);
+  test.AddOutput<int64_t>("per_input_length", {num_inputs}, per_input_length);
+}
+
+// <= 32 inputs tests passes input tensor addresses as kernel args
+TEST(ConcatTrainingOpTest, Concat3D_same_len_16_inputs) {
+  OpTester test("ConcatTraining", 1, kMSDomain);
+  Setup_Concat3D_same_len_N_inputs(test, 16);
+  test.Run();
+}
+
+// > 32 inputs tests passes input tensor addresses in device buffer
+TEST(ConcatTrainingOpTest, Concat3D_same_len_64_inputs) {
+  OpTester test("ConcatTraining", 1, kMSDomain);
+  Setup_Concat3D_same_len_N_inputs(test, 64);
+  test.Run();
+}
+
 TEST(ConcatTrainingOpTest, Concat2D_optional_output1) {
   OpTester test("ConcatTraining", 1, kMSDomain);
   test.AddAttribute("axis", int64_t{1});
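The odata_at lambda in Setup_Concat3D_same_len_N_inputs encodes the row-major layout of the expected [2, 2*num_inputs, 2] output: element (i, j, k) sits at flat offset 4*num_inputs*i + 2*j + k, and input i contributes axis rows 2*i and 2*i+1. A small self-check of that arithmetic for num_inputs = 2, purely illustrative:

#include <cassert>

// Each input has shape [2,2,2]; with num_inputs = 2 the expected output has
// shape [2, 4, 2], so element (i, j, k) lives at flat offset i*(4*2) + j*2 + k.
int main() {
  const int num_inputs = 2;
  auto flat = [num_inputs](int i, int j, int k) { return 4*num_inputs*i + 2*j + k; };
  assert(flat(0, 0, 0) == 0);   // start of input 0's i=0 slice
  assert(flat(0, 2, 0) == 4);   // start of input 1's i=0 slice (rows j = 2, 3)
  assert(flat(1, 0, 0) == 8);   // start of input 0's i=1 slice
  assert(flat(1, 2, 0) == 12);  // start of input 1's i=1 slice
  return 0;
}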
@@ -59,6 +59,40 @@ TEST(SplitTrainingOpTest, Axis0EqualSplitFloat) {
   SplitTrainingOpTester<float>(axis, {}, input, outputs);
 }
 
+std::tuple<ShapeAndFloatData, std::vector<ShapeAndFloatData>>
+Setup_Axis0EqualSplitFloat_N_inputs(const int num_outputs) {
+  float counter = 1.0f;
+  std::vector<float> data(4*num_outputs);
+  std::iota(data.begin(), data.end(), counter);
+  ShapeAndFloatData input = {{2*num_outputs, 2}, data};
+
+  data.resize(4);
+  std::vector<ShapeAndFloatData> outputs;
+  for (int i = 0; i < num_outputs; i++) {
+    std::iota(data.begin(), data.end(), counter);
+    outputs.push_back({{2, 2}, data});
+    counter += (float)data.size();
+  }
+
+  // due to const on ShapeAndFloatData
+  return std::make_tuple(input, outputs);
+}
+
+// <= 32 with same sizes passes output addresses as kernel args
+TEST(SplitTrainingOpTest, Axis0EqualSplitFloat_16_outputs) {
+  const int64_t axis = 0;
+  auto io = Setup_Axis0EqualSplitFloat_N_inputs(16);
+  SplitTrainingOpTester<float>(axis, {}, std::get<0>(io), std::get<1>(io));
+}
+
+// > 32 with same sizes passes output addresses as device buffer
+TEST(SplitTrainingOpTest, Axis0EqualSplitFloat_64_outputs) {
+  const int64_t axis = 0;
+  auto io = Setup_Axis0EqualSplitFloat_N_inputs(64);
+  SplitTrainingOpTester<float>(axis, {}, std::get<0>(io), std::get<1>(io));
+}
+
 TEST(SplitTrainingOpTest, Axis0UnequalSplitFloat) {
   const int64_t axis = 0;
   std::vector<ShapeAndFloatData> outputs;
@@ -50,26 +50,59 @@ Status ConcatTraining::ComputeInternal(OpKernelContext* ctx) const {
     concat_sizes_range[i] += concat_sizes_range[i - 1];
   }
 
-  CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, concat_sizes);
-  CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
-  CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, concat_sizes_range);
-  concat_sizes_gpu.CopyToGpu();
-  axis_dimension_input_output_mapping_gpu.CopyToGpu();
-  concat_sizes_range_gpu.CopyToGpu();
-  input_ptr.CopyToGpu();
-  auto element_bytes = p.output_tensor->DataType()->Size();
   int block_size_inside_axis_dim = static_cast<int>(p.output_axis_pitch / p.output_tensor->Shape()[p.axis]);
   int block_size_including_axis_dim = static_cast<int>(p.output_axis_pitch);
-  ORT_RETURN_IF_ERROR(ConcatImpl(Stream(),
-                                 element_bytes,
-                                 block_size_including_axis_dim,
-                                 block_size_inside_axis_dim,
-                                 concat_sizes_gpu.GpuPtr(),
-                                 concat_sizes_range_gpu.GpuPtr(),
-                                 axis_dimension_input_output_mapping_gpu.GpuPtr(),
-                                 p.output_tensor->MutableDataRaw(),
-                                 input_ptr.GpuPtr(),
-                                 p.output_num_elements));
+  auto element_bytes = p.output_tensor->DataType()->Size();
+  if (std::all_of(concat_sizes.begin(), concat_sizes.end(), [&] (int64_t i) {return i == concat_sizes[0];})) {
+    if (input_count <= 32) {
+      // pass by value to avoid host-to-device copy on same stream
+      TArray<const void*, 32> input_table(input_count);
+      for (int i = 0; i < input_count; ++i) {
+        input_table[i] = input_ptr_cpuspan[i];
+      }
+      ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(Stream(),
+                                                  element_bytes,
+                                                  block_size_including_axis_dim,
+                                                  block_size_inside_axis_dim,
+                                                  concat_sizes[0],
+                                                  p.output_tensor->MutableDataRaw(),
+                                                  input_table,
+                                                  p.output_num_elements));
+    } else {
+      // too many inputs, so copy sizes to device memory
+      input_ptr.CopyToGpu();
+      ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl(Stream(),
+                                                  element_bytes,
+                                                  block_size_including_axis_dim,
+                                                  block_size_inside_axis_dim,
+                                                  concat_sizes[0],
+                                                  p.output_tensor->MutableDataRaw(),
+                                                  input_ptr.GpuPtr(),
+                                                  p.output_num_elements));
+    }
+  } else {
+    // input sizes vary, copy input sizes and range metadata to device
+    // todo: pass by value when few inputs
+    input_ptr.CopyToGpu();
+    CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, concat_sizes);
+    CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
+    CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, concat_sizes_range);
+    concat_sizes_gpu.CopyToGpu();
+    axis_dimension_input_output_mapping_gpu.CopyToGpu();
+    concat_sizes_range_gpu.CopyToGpu();
+
+    ORT_RETURN_IF_ERROR(ConcatImpl(Stream(),
+                                   element_bytes,
+                                   block_size_including_axis_dim,
+                                   block_size_inside_axis_dim,
+                                   concat_sizes_gpu.GpuPtr(),
+                                   concat_sizes_range_gpu.GpuPtr(),
+                                   axis_dimension_input_output_mapping_gpu.GpuPtr(),
+                                   p.output_tensor->MutableDataRaw(),
+                                   input_ptr.GpuPtr(),
+                                   p.output_num_elements));
+  }
 
   // Create optional output tensor for 'per_input_length'
   Tensor* per_input_length_tensor = ctx->Output(1, {input_count});
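For contrast with the new fast path above, the fall-back branches still stage metadata through device memory before the launch. ORT's CudaAsyncBuffer handles that internally; the generic pattern it roughly amounts to is sketched below using cudaMallocAsync (CUDA 11.2+). This is a simplified stand-in for illustration, not the actual CudaAsyncBuffer implementation.

#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

// Stage a host-side metadata array on the stream: the kernel that consumes
// it cannot start until this async H2D copy has landed.
int64_t* stage_to_device(const std::vector<int64_t>& host, cudaStream_t stream) {
  int64_t* dev = nullptr;
  cudaMallocAsync(reinterpret_cast<void**>(&dev), host.size() * sizeof(int64_t), stream);
  cudaMemcpyAsync(dev, host.data(), host.size() * sizeof(int64_t),
                  cudaMemcpyHostToDevice, stream);
  return dev;  // caller releases with cudaFreeAsync after the kernel launch
}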
@@ -65,24 +65,51 @@ Status SplitTraining::ComputeInternal(OpKernelContext* ctx) const {
   }
 
   if (input_tensor->Shape().Size() > 0) {
-    output_ptr.CopyToGpu();
 
-    CudaAsyncBuffer<int64_t> split_sizes_gpu(this, split_sizes);
-    split_sizes_gpu.CopyToGpu();
 
     std::vector<int64_t> split_sizes_range(split_sizes);
     for (size_t i = 1; i < split_sizes_range.size(); ++i) {
      split_sizes_range[i] += split_sizes_range[i - 1];
     }
 
-    CudaAsyncBuffer<int64_t> split_sizes_range_gpu(this, split_sizes_range);
-    split_sizes_range_gpu.CopyToGpu();
 
-    CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
-    axis_dimension_input_output_mapping_gpu.CopyToGpu();
 
     size_t element_size = input_tensor->DataType()->Size();
-    ORT_RETURN_IF_ERROR(SplitImpl(Stream(),
+
+    if (std::all_of(split_sizes.begin(), split_sizes.end(), [&] (int64_t i) {return i == split_sizes[0];})) {
+      if (num_outputs <= 32) {
+        TArray<void*, 32> output_table(num_outputs);
+        for (int i = 0; i < num_outputs; ++i) {
+          output_table[i] = output_ptr_span[i];
+        }
+        ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(),
+                                                  element_size,
+                                                  block_size_including_axis_dim,
+                                                  block_size_inside_axis_dim,
+                                                  split_sizes[0],
+                                                  num_outputs,
+                                                  input_data,
+                                                  output_table,
+                                                  input_shape.Size()));
+      } else {
+        output_ptr.CopyToGpu();
+        ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(),
+                                                  element_size,
+                                                  block_size_including_axis_dim,
+                                                  block_size_inside_axis_dim,
+                                                  split_sizes[0],
+                                                  num_outputs,
+                                                  input_data,
+                                                  output_ptr.GpuPtr(),
+                                                  input_shape.Size()));
+      }
+    } else {
+      output_ptr.CopyToGpu();
+      CudaAsyncBuffer<int64_t> split_sizes_gpu(this, split_sizes);
+      CudaAsyncBuffer<int64_t> split_sizes_range_gpu(this, split_sizes_range);
+      CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping);
+      split_sizes_gpu.CopyToGpu();
+      split_sizes_range_gpu.CopyToGpu();
+      axis_dimension_input_output_mapping_gpu.CopyToGpu();
+
+      ORT_RETURN_IF_ERROR(SplitImpl(Stream(),
                                     element_size,
                                     block_size_including_axis_dim,
                                     block_size_inside_axis_dim,
@@ -93,6 +120,7 @@ Status SplitTraining::ComputeInternal(OpKernelContext* ctx) const {
                                     input_data,
                                     output_ptr.GpuPtr(),
                                     input_shape.Size()));
+    }
   }
 
   return Status::OK();