mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
parent
5f30be3e92
commit
e348929019
1 changed files with 32 additions and 29 deletions
|
|
@ -23,26 +23,26 @@ __global__ void _UpampleNearestKernel(const TArray<int64_t> input_pitches,
|
|||
output_div_pitches[dim].divmod(output_index, div, mod);
|
||||
output_index = mod;
|
||||
if (scales_div[dim].d_ != 1 && div > 0) {
|
||||
scales_div[dim].divmod(div, div, mod);
|
||||
scales_div[dim].divmod(div, div, mod);
|
||||
}
|
||||
input_index += input_pitches[dim] * div;
|
||||
}
|
||||
output_data[id] = input_data[input_index];
|
||||
}
|
||||
|
||||
// The following method supports a 4-D input in 'Linear mode'
|
||||
// The following method supports a 4-D input in 'Linear mode'
|
||||
// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
|
||||
// the scale values for the outermost 2 dimensions are 1.
|
||||
// This is the common use-case where the 4-D input (batched multi-channel images)
|
||||
// This is the common use-case where the 4-D input (batched multi-channel images)
|
||||
// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
|
||||
template <typename T>
|
||||
__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2,
|
||||
const TArray<int64_t> input_pitches,
|
||||
const TArray<fast_divmod> output_div_pitches,
|
||||
const TArray<fast_divmod> scales_div,
|
||||
const T* __restrict__ input_data,
|
||||
T* __restrict__ output_data,
|
||||
const size_t N) {
|
||||
const TArray<int64_t> input_pitches,
|
||||
const TArray<fast_divmod> output_div_pitches,
|
||||
const TArray<fast_divmod> scales_div,
|
||||
const T* __restrict__ input_data,
|
||||
T* __restrict__ output_data,
|
||||
const size_t N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
CUDA_LONG input_index = 0;
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ __global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2,
|
|||
index_of_dim1 * input_pitches[1] +
|
||||
index_of_input_dim2 * input_pitches[2] +
|
||||
index_of_input_dim3;
|
||||
|
||||
|
||||
T x00 = input_data[input_index];
|
||||
T x10, x01, x11;
|
||||
|
||||
|
|
@ -78,8 +78,7 @@ __global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2,
|
|||
// It's the end in dimension 3
|
||||
x10 = x00;
|
||||
x11 = x01;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
x10 = input_data[input_index + 1];
|
||||
x11 = end_of_dim2 ? x10 : input_data[input_index + input_pitches[2] + 1];
|
||||
}
|
||||
|
|
@ -161,25 +160,25 @@ void UpampleImpl(cudaStream_t stream,
|
|||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
if (onnxruntime::UpsampleMode::NN == upsample_mode) {
|
||||
if (rank == 4) {
|
||||
_UpampleNearestKernel<T,4><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_UpampleNearestKernel<T, 4><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
input_pitches, output_div_pitches, scales_div,
|
||||
input_data, output_data, N);
|
||||
} else if (rank == 3) {
|
||||
_UpampleNearestKernel<T,3><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_UpampleNearestKernel<T, 3><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
input_pitches, output_div_pitches, scales_div,
|
||||
input_data, output_data, N);
|
||||
} else if (rank == 2) {
|
||||
_UpampleNearestKernel<T,2><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_UpampleNearestKernel<T, 2><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
input_pitches, output_div_pitches, scales_div,
|
||||
input_data, output_data, N);
|
||||
} else if (rank == 1) {
|
||||
_UpampleNearestKernel<T,1><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
_UpampleNearestKernel<T, 1><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
input_pitches, output_div_pitches, scales_div,
|
||||
input_data, output_data, N);
|
||||
} else {
|
||||
ORT_THROW("Unsupported rank by the Upsample CUDA kernel");
|
||||
ORT_THROW("Unsupported rank by the Upsample CUDA kernel. Input rank: ", rank);
|
||||
}
|
||||
} else if (onnxruntime::UpsampleMode::LINEAR) {
|
||||
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) {
|
||||
if (rank == 4) {
|
||||
_UpampleBilinear4DInputKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
|
||||
input_dim2, input_pitches, output_div_pitches, scales_div,
|
||||
|
|
@ -189,21 +188,25 @@ void UpampleImpl(cudaStream_t stream,
|
|||
input_dim2, input_pitches, output_div_pitches, scales_div,
|
||||
input_data, output_data, N);
|
||||
} else {
|
||||
ORT_THROW("Unsupported rank by the Upsample CUDA kernel");
|
||||
ORT_THROW("Unsupported rank by the Upsample CUDA kernel. Input rank: ", rank);
|
||||
}
|
||||
} else {
|
||||
// Should never encounter this as Upsample only supports 'Nearest' and 'Linear' modes.
|
||||
// But if we do encounter this, it is best to throw instead of returning silently.
|
||||
ORT_THROW("Unsupported mode for Upsample: ", upsample_mode);
|
||||
}
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void UpampleImpl<T>(cudaStream_t stream, \
|
||||
const onnxruntime::UpsampleMode upsample_mode, \
|
||||
const size_t rank, \
|
||||
const int64_t input_dim2, \
|
||||
const TArray<int64_t>& input_pitches, \
|
||||
const TArray<fast_divmod>& output_div_pitches, \
|
||||
const TArray<fast_divmod>& scales_div, \
|
||||
const T* input_data, \
|
||||
T* output_data, \
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void UpampleImpl<T>(cudaStream_t stream, \
|
||||
const onnxruntime::UpsampleMode upsample_mode, \
|
||||
const size_t rank, \
|
||||
const int64_t input_dim2, \
|
||||
const TArray<int64_t>& input_pitches, \
|
||||
const TArray<fast_divmod>& output_div_pitches, \
|
||||
const TArray<fast_divmod>& scales_div, \
|
||||
const T* input_data, \
|
||||
T* output_data, \
|
||||
const size_t N);
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
|
|
|
|||
Loading…
Reference in a new issue