mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
Replacing CudaAsyncBuffer with TArray to improve perf (#3303)
* removing using CudaAsyncBuffer * Keep CudaAsyncBuffer for these ops: non_max_suppression, cudnn_rnn_base, concat, split * fix windows build error * fix windows build error. * fix build error * fix windows build error Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
parent
ef7b98f988
commit
fef7989866
32 changed files with 423 additions and 357 deletions
|
|
@ -27,12 +27,11 @@ namespace cuda {
|
|||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
UnaryElementwisePreparation p; \
|
||||
UnaryElementwise::Prepare(context, &p); \
|
||||
CudaAsyncBuffer<Ctx##x> func_ctx(this, MakeFuncCtx(), 1); \
|
||||
if (!std::is_same<CtxNull, Ctx##x>::value) ORT_RETURN_IF_ERROR(func_ctx.CopyToGpu()); \
|
||||
Ctx##x func_ctx = MakeFuncCtx(); \
|
||||
Impl_##x<typename ToCudaType<T>::MappedType>( \
|
||||
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->template Data<T>()), \
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->template MutableData<T>()), \
|
||||
func_ctx.GpuPtr(), p.output_tensor->Shape().Size()); \
|
||||
&func_ctx, p.output_tensor->Shape().Size()); \
|
||||
\
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,12 +23,11 @@ namespace cuda {
|
|||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
UnaryElementwisePreparation p; \
|
||||
UnaryElementwise::Prepare(context, &p); \
|
||||
CudaAsyncBuffer<Ctx##x> func_ctx(this, MakeFuncCtx(), 1); \
|
||||
if (!std::is_same<CtxNull, Ctx##x>::value) ORT_RETURN_IF_ERROR(func_ctx.CopyToGpu()); \
|
||||
Ctx##x func_ctx = MakeFuncCtx(); \
|
||||
Impl_##x<typename ToCudaType<T>::MappedType>( \
|
||||
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->template Data<T>()), \
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->template MutableData<T>()), \
|
||||
func_ctx.GpuPtr(), p.output_tensor->Shape().Size()); \
|
||||
&func_ctx, p.output_tensor->Shape().Size()); \
|
||||
\
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, i
|
|||
__global__ void _UnaryElementWise(
|
||||
const InT* input_data,
|
||||
OutT* output_data,
|
||||
const FuncT& functor,
|
||||
const FuncT functor,
|
||||
CUDA_LONG N) {
|
||||
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
InT value[NumElementsPerThread];
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
|
|||
#define TOPKIMPL(T) TopKImpl<T>(this, tensor_X->Data<T>(), \
|
||||
static_cast<T*>(tensor_V->MutableDataRaw()), \
|
||||
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
|
||||
elem_nums_cuda.GpuPtr(), \
|
||||
elem_nums_cuda, \
|
||||
elem_nums.size(), \
|
||||
axis, K_, largest_, sorted_, N, dimension)
|
||||
|
||||
|
|
@ -53,8 +53,8 @@ template <bool inputk>
|
|||
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
auto tensor_X = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(nullptr != tensor_X);
|
||||
auto rank = static_cast<int64_t>(tensor_X->Shape().NumDimensions());
|
||||
auto axis = axis_ < 0 ? rank + axis_ : axis_;
|
||||
int32_t rank = static_cast<int32_t>(tensor_X->Shape().NumDimensions());
|
||||
int32_t axis = static_cast<int32_t>(axis_ < 0 ? rank + axis_ : axis_);
|
||||
ORT_ENFORCE(axis > -1 && axis < rank);
|
||||
|
||||
if (inputk) {
|
||||
|
|
@ -80,8 +80,7 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
}
|
||||
|
||||
auto N = elem_nums[0] / dimension;
|
||||
CudaAsyncBuffer<int64_t> elem_nums_cuda(this, elem_nums);
|
||||
ORT_RETURN_IF_ERROR(elem_nums_cuda.CopyToGpu());
|
||||
TArray<int64_t> elem_nums_cuda(elem_nums);
|
||||
|
||||
auto prim_type = tensor_X->DataType()->AsPrimitiveDataType();
|
||||
if (prim_type == nullptr) {
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ struct KV {
|
|||
#define LESS(n, m) ((n) <= (m) ? (n) : (m))
|
||||
|
||||
template <typename T>
|
||||
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
|
||||
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
|
||||
auto tid = threadIdx.x;
|
||||
auto bid = blockIdx.x;
|
||||
extern __shared__ char shared_mem[];
|
||||
|
|
@ -192,7 +192,7 @@ __device__ void SetByte(double* d, int64_t byte) {
|
|||
}
|
||||
|
||||
template<typename T, int64_t THREADS, int64_t KPT>
|
||||
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
|
||||
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
|
||||
auto tid = threadIdx.x;
|
||||
auto bid = blockIdx.x;
|
||||
extern __shared__ char shared_mem[];
|
||||
|
|
@ -342,7 +342,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
|
||||
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis];
|
||||
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
|
||||
|
|
@ -352,7 +352,7 @@ __global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, cons
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, K);
|
||||
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis] * K / dimension;
|
||||
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
|
||||
|
|
@ -369,7 +369,7 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
|
||||
auto aligned_K = ALIGN(K);
|
||||
auto aligned_dimension = ALIGN(dimension);
|
||||
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
|
||||
|
|
@ -419,9 +419,9 @@ Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t
|
|||
const T* input_x, \
|
||||
T* output_v, \
|
||||
int64_t* output_i, \
|
||||
const int64_t* elem_nums, \
|
||||
const TArray<int64_t>& elem_nums, \
|
||||
size_t size, \
|
||||
int64_t axis, \
|
||||
int32_t axis, \
|
||||
int64_t K, \
|
||||
int64_t largest, \
|
||||
int64_t sorted, \
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
|
||||
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -84,13 +84,17 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
|
|||
CalcEffectiveDims(input_dims, output_dims);
|
||||
int rank = gsl::narrow_cast<int>(output_dims.size());
|
||||
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, rank);
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));
|
||||
TensorPitches original_input_strides(input_dims);
|
||||
TensorPitches original_output_strides(output_dims);
|
||||
|
||||
CudaAsyncBuffer<int64_t> input_view_strides(this, rank);
|
||||
TensorPitches::Calculate(input_view_strides.CpuSpan(), input_dims);
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (input_dims[i] == 1) input_view_strides.CpuSpan()[i] = 0;
|
||||
TArray<int64_t> input_strides(rank);
|
||||
for (auto i = 0; i < rank; i++) {
|
||||
input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i];
|
||||
}
|
||||
|
||||
TArray<fast_divmod> output_strides(rank);
|
||||
for (auto i = 0; i < rank; i++) {
|
||||
output_strides[i] = fast_divmod(static_cast<int>(original_output_strides[i]));
|
||||
}
|
||||
|
||||
return ExpandImpl(
|
||||
|
|
@ -99,8 +103,8 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
|
|||
gsl::narrow_cast<int>(input_data_tensor.Shape().Size()),
|
||||
input_data_tensor.DataRaw(),
|
||||
output_tensor.MutableDataRaw(),
|
||||
fdm_output_strides,
|
||||
input_view_strides);
|
||||
output_strides,
|
||||
input_strides);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -50,14 +50,14 @@ __global__ void ExpandKernel(
|
|||
const int N,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const int64_t* input_view_strides) {
|
||||
const TArray<fast_divmod> output_strides,
|
||||
const TArray<int64_t> input_strides) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
|
||||
int dim, r = id, input_index = 0;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
fdm_output_strides[i].divmod(r, dim, r);
|
||||
input_index += dim * input_view_strides[i];
|
||||
output_strides[i].divmod(r, dim, r);
|
||||
input_index += dim * input_strides[i];
|
||||
}
|
||||
output_data[id] = input_data[input_index];
|
||||
}
|
||||
|
|
@ -114,9 +114,9 @@ Status ExpandImpl(
|
|||
const int N_input,
|
||||
const void* input_data,
|
||||
void* output_data,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& fdm_output_strides,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_view_strides) {
|
||||
const int rank = static_cast<int>(fdm_output_strides.count());
|
||||
const TArray<fast_divmod>& output_strides,
|
||||
const TArray<int64_t>& input_strides) {
|
||||
const int rank = static_cast<int>(output_strides.size_);
|
||||
if (rank == 1) {
|
||||
if (N_input == N_output) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_data, input_data, N_output * element_size, cudaMemcpyDeviceToDevice));
|
||||
|
|
@ -125,20 +125,18 @@ Status ExpandImpl(
|
|||
}
|
||||
} else if (rank == 2) {
|
||||
return Expand2D(element_size, N_output, input_data, output_data,
|
||||
fdm_output_strides.CpuSpan()[0],
|
||||
static_cast<int>(input_view_strides.CpuSpan()[0]),
|
||||
static_cast<int>(input_view_strides.CpuSpan()[1]));
|
||||
output_strides[0],
|
||||
static_cast<int>(input_strides[0]),
|
||||
static_cast<int>(input_strides[1]));
|
||||
}
|
||||
|
||||
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(N_output, GridDim::maxThreadsPerBlock));
|
||||
fdm_output_strides.CopyToGpu();
|
||||
input_view_strides.CopyToGpu();
|
||||
|
||||
#define EXPAND_ON(TYPE) \
|
||||
case sizeof(TYPE): \
|
||||
ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>( \
|
||||
rank, N_output, reinterpret_cast<const TYPE*>(input_data), reinterpret_cast<TYPE*>(output_data), \
|
||||
fdm_output_strides.GpuPtr(), input_view_strides.GpuPtr()); \
|
||||
output_strides, input_strides); \
|
||||
break
|
||||
|
||||
switch (element_size) {
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ Status ExpandImpl(
|
|||
const int N_input,
|
||||
const void* input_data,
|
||||
void* output_data,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& fdm_output_strides,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_view_strides);
|
||||
const TArray<fast_divmod>& output_strides,
|
||||
const TArray<int64_t>& input_strides);
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ Status GatherElements::ComputeInternal(OpKernelContext* context) const {
|
|||
const auto* indices_tensor = context->Input<Tensor>(1);
|
||||
const auto& indices_shape = indices_tensor->Shape();
|
||||
const auto& indices_dims = indices_shape.GetDims();
|
||||
const int64_t indices_rank = static_cast<int64_t>(indices_dims.size());
|
||||
const int32_t indices_rank = static_cast<int32_t>(indices_dims.size());
|
||||
const int64_t indices_size = indices_shape.Size();
|
||||
|
||||
// Handle negative axis if any
|
||||
|
|
@ -51,13 +51,13 @@ Status GatherElements::ComputeInternal(OpKernelContext* context) const {
|
|||
return Status::OK();
|
||||
|
||||
TensorPitches input_strides(input_dims);
|
||||
CudaAsyncBuffer<int64_t> gpu_input_strides(this, input_strides);
|
||||
TArray<int64_t> gpu_input_strides(input_strides);
|
||||
|
||||
CudaAsyncBuffer<fast_divmod> fdm_indices_strides(this, indices_rank);
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_indices_strides.CpuSpan(), indices_dims));
|
||||
|
||||
ORT_RETURN_IF_ERROR(gpu_input_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_indices_strides.CopyToGpu());
|
||||
TArray<fast_divmod> fdm_indices_strides(indices_rank);
|
||||
TensorPitches indices_strides(indices_dims);
|
||||
for (auto i = 0; i < indices_rank; i++) {
|
||||
fdm_indices_strides[i] = fast_divmod(static_cast<int>(indices_strides[i]));
|
||||
}
|
||||
|
||||
size_t element_size = input_tensor->DataType()->Size();
|
||||
|
||||
|
|
@ -67,10 +67,10 @@ Status GatherElements::ComputeInternal(OpKernelContext* context) const {
|
|||
input_rank,
|
||||
input_tensor->DataRaw(),
|
||||
input_dims[axis],
|
||||
gpu_input_strides.GpuPtr(),
|
||||
gpu_input_strides,
|
||||
indices_data,
|
||||
indices_size,
|
||||
fdm_indices_strides.GpuPtr(),
|
||||
fdm_indices_strides,
|
||||
axis,
|
||||
output_tensor->MutableDataRaw(),
|
||||
element_size);
|
||||
|
|
@ -81,10 +81,10 @@ Status GatherElements::ComputeInternal(OpKernelContext* context) const {
|
|||
input_rank,
|
||||
input_tensor->DataRaw(),
|
||||
input_dims[axis],
|
||||
gpu_input_strides.GpuPtr(),
|
||||
gpu_input_strides,
|
||||
indices_data,
|
||||
indices_size,
|
||||
fdm_indices_strides.GpuPtr(),
|
||||
fdm_indices_strides,
|
||||
axis,
|
||||
output_tensor->MutableDataRaw(),
|
||||
element_size);
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ __global__ void _GatherElementsKernel(
|
|||
const int64_t rank,
|
||||
const T* input_data,
|
||||
const int64_t input_dim_along_axis,
|
||||
const int64_t* input_strides,
|
||||
const TArray<int64_t> input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
const fast_divmod* indices_strides,
|
||||
const TArray<fast_divmod> indices_strides,
|
||||
const int64_t axis,
|
||||
T* output_data) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(indices_index, indices_size);
|
||||
|
|
@ -43,10 +43,10 @@ void GatherElementsImpl(
|
|||
const int64_t rank,
|
||||
const void* input_data,
|
||||
const int64_t input_dim_along_axis,
|
||||
const int64_t* input_strides,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
const fast_divmod* indices_strides,
|
||||
const TArray<fast_divmod>& indices_strides,
|
||||
const int64_t axis,
|
||||
void* output_data,
|
||||
size_t element_size) {
|
||||
|
|
@ -95,10 +95,10 @@ template void GatherElementsImpl<int32_t>(
|
|||
const int64_t rank,
|
||||
const void* input_data,
|
||||
const int64_t input_dim_along_axis,
|
||||
const int64_t* input_strides,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const int32_t* indices_data,
|
||||
const int64_t indices_size,
|
||||
const fast_divmod* indices_strides,
|
||||
const TArray<fast_divmod>& indices_strides,
|
||||
const int64_t axis,
|
||||
void* output_data,
|
||||
size_t element_size);
|
||||
|
|
@ -107,10 +107,10 @@ template void GatherElementsImpl<int64_t>(
|
|||
const int64_t rank,
|
||||
const void* input_data,
|
||||
const int64_t input_dim_along_axis,
|
||||
const int64_t* input_strides,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const int64_t* indices_data,
|
||||
const int64_t indices_size,
|
||||
const fast_divmod* indices_strides,
|
||||
const TArray<fast_divmod>& indices_strides,
|
||||
const int64_t axis,
|
||||
void* output_data,
|
||||
size_t element_size);
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ void GatherElementsImpl(
|
|||
const int64_t rank, // both inputs have same rank and this is validated in the main Compute
|
||||
const void* input_data,
|
||||
const int64_t input_dim_along_axis,
|
||||
const int64_t* input_strides,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
const fast_divmod* indices_strides,
|
||||
const TArray<fast_divmod>& indices_strides,
|
||||
const int64_t axis,
|
||||
void* output_data,
|
||||
size_t element_size);
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ __global__ void NonZeroCountEachBlockKernel(const InputT* x, int64_t x_size, int
|
|||
|
||||
template <typename InputT, int THREADS_PER_BLOCK>
|
||||
__global__ void NonZeroOutputPositionsKernel(
|
||||
const InputT* x, int64_t x_size, int x_rank, const fast_divmod* x_strides,
|
||||
const InputT* x, int64_t x_size, int x_rank, const TArray<fast_divmod> x_strides,
|
||||
const int* prefix_counts, int nonzero_elements, int64_t* results) {
|
||||
typedef cub::BlockScan<int, THREADS_PER_BLOCK> BlockScanT;
|
||||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
|
|
@ -78,7 +78,7 @@ cudaError_t NonZeroCountEachBlock(const InputT* x, int64_t x_size, int* count_in
|
|||
|
||||
template <typename InputT>
|
||||
cudaError_t NonZeroOutputPositions(
|
||||
const InputT* x, int64_t x_size, int x_rank, const fast_divmod* x_strides,
|
||||
const InputT* x, int64_t x_size, int x_rank, const TArray<fast_divmod>& x_strides,
|
||||
const int* prefix_counts, int nonzero_elements, int64_t* results) {
|
||||
int num_blocks = NonZeroCalcBlockCount(x_size);
|
||||
NonZeroOutputPositionsKernel<InputT, NONZERO_THREADS_PER_BLOCK><<<num_blocks, NONZERO_THREADS_PER_BLOCK>>>(
|
||||
|
|
@ -94,12 +94,12 @@ template cudaError_t NonZeroCountEachBlock(const int32_t*, int64_t, int*);
|
|||
template cudaError_t NonZeroCountEachBlock(const float*, int64_t, int*);
|
||||
template cudaError_t NonZeroCountEachBlock(const half*, int64_t, int*);
|
||||
|
||||
template cudaError_t NonZeroOutputPositions(const bool*, int64_t, int, const fast_divmod*, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const uint8_t*, int64_t, int, const fast_divmod*, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const int64_t*, int64_t, int, const fast_divmod*, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const int32_t*, int64_t, int, const fast_divmod*, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const float*, int64_t, int, const fast_divmod*, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const half*, int64_t, int, const fast_divmod*, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const bool*, int64_t, int, const TArray<fast_divmod>&, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const uint8_t*, int64_t, int, const TArray<fast_divmod>&, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const int64_t*, int64_t, int, const TArray<fast_divmod>&, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const int32_t*, int64_t, int, const TArray<fast_divmod>&, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const float*, int64_t, int, const TArray<fast_divmod>&, const int*, int, int64_t*);
|
||||
template cudaError_t NonZeroOutputPositions(const half*, int64_t, int, const TArray<fast_divmod>&, const int*, int, int64_t*);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ cudaError_t NonZeroCountEachBlock(const InputT* x, int64_t x_size, int* counts_i
|
|||
// output nonzero positions using input x and prefix_counts for each blocks
|
||||
template<typename InputT>
|
||||
cudaError_t NonZeroOutputPositions(
|
||||
const InputT *x, int64_t x_size, int x_rank, const fast_divmod* x_strides,
|
||||
const InputT *x, int64_t x_size, int x_rank, const TArray<fast_divmod>& x_strides,
|
||||
const int* prefix_counts, int nonzero_elements, int64_t* results);
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -70,14 +70,16 @@ Status NonZero<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
&nonzero_elements, prefix_counts + number_of_blocks - 1,
|
||||
sizeof(int), cudaMemcpyDeviceToHost));
|
||||
|
||||
CudaAsyncBuffer<fast_divmod> fdm_x_strides(this, x_rank);
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_x_strides.CpuSpan(), x_dims));
|
||||
ORT_RETURN_IF_ERROR(fdm_x_strides.CopyToGpu());
|
||||
TArray<fast_divmod> fdm_x_strides(x_rank);
|
||||
TensorPitches x_strides(x_dims);
|
||||
for (auto i = 0; i < x_rank; i++) {
|
||||
fdm_x_strides[i] = fast_divmod(static_cast<int>(x_strides[i]));
|
||||
}
|
||||
|
||||
auto* output_tensor = context->Output(0, {x_rank, nonzero_elements});
|
||||
ORT_ENFORCE(output_tensor, "failed to get first output!");
|
||||
CUDA_RETURN_IF_ERROR(NonZeroOutputPositions(
|
||||
x_data, x_size, x_rank, fdm_x_strides.GpuPtr(),
|
||||
x_data, x_size, x_rank, fdm_x_strides,
|
||||
prefix_counts, nonzero_elements, output_tensor->template MutableData<int64_t>()));
|
||||
} else {
|
||||
context->Output(0, {x_rank, nonzero_elements});
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const auto& input_tensor = *ctx->Input<Tensor>(0);
|
||||
auto const& input_shape = input_tensor.Shape();
|
||||
auto dimension_count = input_shape.NumDimensions();
|
||||
int32_t dimension_count = static_cast<int32_t>(input_shape.NumDimensions());
|
||||
|
||||
const std::vector<int64_t>* p_pads = &pads_;
|
||||
const std::vector<int64_t>* p_slices = &slices_;
|
||||
|
|
@ -94,23 +94,20 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
p_slices = &slices;
|
||||
}
|
||||
|
||||
CudaAsyncBuffer<int64_t> input_dims(this, input_shape.GetDims());
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, dimension_count);
|
||||
CudaAsyncBuffer<int64_t> lower_pads(this, dimension_count);
|
||||
CudaAsyncBuffer<int64_t> upper_pads(this, dimension_count);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, dimension_count);
|
||||
TensorPitches input_pitches(input_shape.GetDims());
|
||||
TArray<int64_t> input_dims(input_shape.GetDims());
|
||||
TArray<int64_t> input_strides(input_pitches);
|
||||
|
||||
TensorPitches::Calculate(input_strides.CpuSpan(), input_shape.GetDims());
|
||||
std::vector<int64_t> output_dims(input_shape.GetDims());
|
||||
ORT_ENFORCE(dimension_count * 2 == p_pads->size(), "'pads' attribute has wrong number of values");
|
||||
|
||||
// Calculate output dimensions, and handle any negative padding
|
||||
auto lower_pads_span = lower_pads.CpuSpan();
|
||||
auto upper_pads_span = upper_pads.CpuSpan();
|
||||
for (size_t i = 0; i < dimension_count; i++) {
|
||||
lower_pads_span[i] = (*p_pads)[i] + (*p_slices)[i];
|
||||
upper_pads_span[i] = (*p_pads)[i + dimension_count] + (*p_slices)[i + dimension_count];
|
||||
output_dims[i] += lower_pads_span[i] + upper_pads_span[i];
|
||||
TArray<int64_t> lower_pads(dimension_count);
|
||||
TArray<int64_t> upper_pads(dimension_count);
|
||||
for (auto i = 0; i < dimension_count; i++) {
|
||||
lower_pads[i] = (*p_pads)[i] + (*p_slices)[i];
|
||||
upper_pads[i] = (*p_pads)[i + dimension_count] + (*p_slices)[i + dimension_count];
|
||||
output_dims[i] += lower_pads[i] + upper_pads[i];
|
||||
}
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
|
|
@ -130,23 +127,22 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));
|
||||
ORT_RETURN_IF_ERROR(input_dims.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(lower_pads.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(upper_pads.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());
|
||||
TArray<fast_divmod> fdm_output_strides(dimension_count);
|
||||
TensorPitches output_strides(output_dims);
|
||||
for (auto i = 0; i < dimension_count; i++) {
|
||||
fdm_output_strides[i] = fast_divmod(static_cast<int>(output_strides[i]));
|
||||
}
|
||||
|
||||
PadImpl(
|
||||
dimension_count,
|
||||
input_dims.GpuPtr(),
|
||||
input_strides.GpuPtr(),
|
||||
lower_pads.GpuPtr(),
|
||||
upper_pads.GpuPtr(),
|
||||
input_dims,
|
||||
input_strides,
|
||||
lower_pads,
|
||||
upper_pads,
|
||||
value,
|
||||
static_cast<int>(mode_),
|
||||
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(input_tensor.template Data<T>()),
|
||||
fdm_output_strides.GpuPtr(),
|
||||
fdm_output_strides,
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(output_tensor.template MutableData<T>()),
|
||||
output_tensor.Shape().Size());
|
||||
|
||||
|
|
|
|||
|
|
@ -17,13 +17,13 @@ enum class PadMode : int {
|
|||
template <typename T, int pad_mode>
|
||||
__global__ void _PadKernel(
|
||||
const size_t shape_rank,
|
||||
const int64_t* input_dims,
|
||||
const int64_t* input_strides,
|
||||
const int64_t* lower_pads,
|
||||
const int64_t* upper_pads,
|
||||
const TArray<int64_t> input_dims,
|
||||
const TArray<int64_t> input_strides,
|
||||
const TArray<int64_t> lower_pads,
|
||||
const TArray<int64_t> upper_pads,
|
||||
const T pad_value,
|
||||
const T* input_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod> fdm_output_strides,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
|
|
@ -70,14 +70,14 @@ __global__ void _PadKernel(
|
|||
template <typename T>
|
||||
void PadImpl(
|
||||
const size_t shape_rank,
|
||||
const int64_t* input_dims,
|
||||
const int64_t* input_strides,
|
||||
const int64_t* lower_pads,
|
||||
const int64_t* upper_pads,
|
||||
const TArray<int64_t>& input_dims,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const TArray<int64_t>& lower_pads,
|
||||
const TArray<int64_t>& upper_pads,
|
||||
const T pad_value,
|
||||
const int pad_mode,
|
||||
const T* input_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
if (N == 0) // special case where there's a dim value of 0 in the output shape
|
||||
|
|
@ -104,7 +104,7 @@ void PadImpl(
|
|||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void PadImpl<T>(const size_t shape_rank, const int64_t* input_dims, const int64_t* input_strides, const int64_t* lower_pads, const int64_t* upper_pads, const T pad_value, const int pad_mode, const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, const size_t N);
|
||||
template void PadImpl<T>(const size_t shape_rank, const TArray<int64_t>& input_dims, const TArray<int64_t>& input_strides, const TArray<int64_t>& lower_pads, const TArray<int64_t>& upper_pads, const T pad_value, const int pad_mode, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N);
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double)
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ namespace cuda {
|
|||
template <typename T>
|
||||
void PadImpl(
|
||||
const size_t shape_rank,
|
||||
const int64_t* input_dims,
|
||||
const int64_t* input_strides,
|
||||
const int64_t* lower_pads,
|
||||
const int64_t* upper_pads,
|
||||
const TArray<int64_t>& input_dims,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const TArray<int64_t>& lower_pads,
|
||||
const TArray<int64_t>& upper_pads,
|
||||
const T pad_value,
|
||||
const int pad_mode,
|
||||
const T* input_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
const size_t N);
|
||||
|
||||
|
|
|
|||
|
|
@ -172,10 +172,10 @@ __global__ void _ResizeNearestMappingKernel2D(
|
|||
template <typename T>
|
||||
__global__ void _ResizeNearestMappingKernel(
|
||||
const size_t rank,
|
||||
const int64_t* input_shape,
|
||||
const int64_t* output_shape,
|
||||
const float* scales,
|
||||
const float* roi,
|
||||
const TArray<int64_t> input_shape,
|
||||
const TArray<int64_t> output_shape,
|
||||
const TArray<float> scales,
|
||||
const TArray<float> roi,
|
||||
const size_t total_dim_sum,
|
||||
bool extrapolation_enabled,
|
||||
CudaFunctionOriginalCoordinate transform_coordinate,
|
||||
|
|
@ -230,8 +230,8 @@ __global__ void _ResizeNearestKernel2D(
|
|||
template <typename T>
|
||||
__global__ void _ResizeNearestKernel(
|
||||
const int rank,
|
||||
const int64_t* input_strides,
|
||||
const fast_divmod* output_div_pitches,
|
||||
const TArray<int64_t> input_strides,
|
||||
const TArray<fast_divmod> output_div_pitches,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N,
|
||||
|
|
@ -447,12 +447,12 @@ size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
|
|||
template <typename T>
|
||||
void ResizeNearestImpl(
|
||||
const int rank,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_shape,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& output_shape,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_strides,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& output_div_pitches,
|
||||
CudaKernel::CudaAsyncBuffer<float>& scales_vals,
|
||||
CudaKernel::CudaAsyncBuffer<float>& roi_vals,
|
||||
TArray<int64_t>& input_shape,
|
||||
TArray<int64_t>& output_shape,
|
||||
TArray<int64_t>& input_strides,
|
||||
TArray<fast_divmod>& output_div_pitches,
|
||||
TArray<float>& scales_vals,
|
||||
TArray<float>& roi_vals,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N,
|
||||
|
|
@ -467,34 +467,34 @@ void ResizeNearestImpl(
|
|||
|
||||
bool could2d = rank >= 2 &&
|
||||
transform_coordinate != GetDeviceOriginalCoordinateFunc(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE) &&
|
||||
std::all_of(scales_vals.CpuPtr(), scales_vals.CpuPtr() + (rank - 2), [](float v) { return v == 1.0; });
|
||||
std::all_of(scales_vals.data_, scales_vals.data_ + (rank - 2), [](float v) { return v == 1.0; });
|
||||
if (could2d) {
|
||||
int64_t output_height = output_shape.CpuPtr()[rank - 2];
|
||||
int64_t output_width = output_shape.CpuPtr()[rank - 1];
|
||||
fast_divmod div_output_image = (rank > 2) ? output_div_pitches.CpuPtr()[rank - 3] : fast_divmod(output_height * output_width);
|
||||
int64_t output_height = output_shape[rank - 2];
|
||||
int64_t output_width = output_shape[rank - 1];
|
||||
fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(output_height * output_width);
|
||||
int blocksPerDimsMappingGrid = (int)(ceil((output_height + output_width) / 32.0));
|
||||
|
||||
_ResizeNearestMappingKernel2D<T><<<blocksPerDimsMappingGrid, 32, 0>>>(
|
||||
input_shape.CpuPtr()[rank - 2], input_shape.CpuPtr()[rank - 1],
|
||||
input_shape[rank - 2], input_shape[rank - 1],
|
||||
output_height, output_width,
|
||||
scales_vals.CpuPtr()[rank - 2], scales_vals.CpuPtr()[rank - 1],
|
||||
roi_vals.CpuPtr()[rank - 2], roi_vals.CpuPtr()[rank - 2 + rank],
|
||||
roi_vals.CpuPtr()[rank - 1], roi_vals.CpuPtr()[rank - 1 + rank],
|
||||
scales_vals[rank - 2], scales_vals[rank - 1],
|
||||
roi_vals[rank - 2], roi_vals[rank - 2 + rank],
|
||||
roi_vals[rank - 1], roi_vals[rank - 1 + rank],
|
||||
extrapolation_enabled, transform_coordinate, calc_nearest_pixel,
|
||||
dims_mapping);
|
||||
if (extrapolation_enabled) {
|
||||
_ResizeNearestKernel2D<T, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_height, output_width,
|
||||
input_shape.CpuPtr()[rank - 2] * input_shape.CpuPtr()[rank - 1], input_shape.CpuPtr()[rank - 1],
|
||||
div_output_image, output_div_pitches.CpuPtr()[rank - 2],
|
||||
input_shape[rank - 2] * input_shape[rank - 1], input_shape[rank - 1],
|
||||
div_output_image, output_div_pitches[rank - 2],
|
||||
input_data, output_data, N,
|
||||
extrapolation_value,
|
||||
dims_mapping);
|
||||
} else {
|
||||
_ResizeNearestKernel2D<T, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_height, output_width,
|
||||
input_shape.CpuPtr()[rank - 2] * input_shape.CpuPtr()[rank - 1], input_shape.CpuPtr()[rank - 1],
|
||||
div_output_image, output_div_pitches.CpuPtr()[rank - 2],
|
||||
input_shape[rank - 2] * input_shape[rank - 1], input_shape[rank - 1],
|
||||
div_output_image, output_div_pitches[rank - 2],
|
||||
input_data, output_data, N,
|
||||
extrapolation_value,
|
||||
dims_mapping);
|
||||
|
|
@ -502,23 +502,17 @@ void ResizeNearestImpl(
|
|||
return;
|
||||
}
|
||||
|
||||
int64_t total_dim_sum = std::accumulate(output_shape.CpuPtr(), output_shape.CpuPtr() + rank, 0);
|
||||
int64_t total_dim_sum = std::accumulate(output_shape.data_, output_shape.data_ + rank, 0);
|
||||
int blocksPerDimsMappingGrid = (int)(ceil(static_cast<double>(total_dim_sum) / 32));
|
||||
input_shape.CopyToGpu();
|
||||
output_shape.CopyToGpu();
|
||||
roi_vals.CopyToGpu();
|
||||
scales_vals.CopyToGpu();
|
||||
input_strides.CopyToGpu();
|
||||
output_div_pitches.CopyToGpu();
|
||||
_ResizeNearestMappingKernel<T><<<blocksPerDimsMappingGrid, 32, 0>>>(
|
||||
rank, input_shape.GpuPtr(), output_shape.GpuPtr(),
|
||||
scales_vals.GpuPtr(), roi_vals.GpuPtr(),
|
||||
rank, input_shape, output_shape,
|
||||
scales_vals, roi_vals,
|
||||
total_dim_sum, extrapolation_enabled,
|
||||
transform_coordinate, calc_nearest_pixel,
|
||||
reinterpret_cast<int64_t*>(dims_mapping),
|
||||
reinterpret_cast<NearestMappingInfo*>(reinterpret_cast<int64_t*>(dims_mapping) + rank));
|
||||
_ResizeNearestKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, input_strides.GpuPtr(), output_div_pitches.GpuPtr(),
|
||||
rank, input_strides, output_div_pitches,
|
||||
input_data, output_data, N,
|
||||
extrapolation_value,
|
||||
reinterpret_cast<const int64_t*>(dims_mapping),
|
||||
|
|
@ -530,12 +524,12 @@ template <typename T>
|
|||
void ResizeImpl(
|
||||
const UpsampleMode upsample_mode,
|
||||
const int rank,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_shape,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& output_shape,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_strides,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& output_div_pitches,
|
||||
CudaKernel::CudaAsyncBuffer<float>& scales_vals,
|
||||
CudaKernel::CudaAsyncBuffer<float>& roi_vals,
|
||||
TArray<int64_t>& input_shape,
|
||||
TArray<int64_t>& output_shape,
|
||||
TArray<int64_t>& input_strides,
|
||||
TArray<fast_divmod>& output_div_pitches,
|
||||
TArray<float>& scales_vals,
|
||||
TArray<float>& roi_vals,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N,
|
||||
|
|
@ -546,7 +540,7 @@ void ResizeImpl(
|
|||
ResizeCoordinateTransformationMode coordinate_transform_mode,
|
||||
ResizeNearestMode nearest_mode,
|
||||
void* dims_mapping) {
|
||||
bool isSame = std::all_of(scales_vals.CpuPtr(), scales_vals.CpuPtr() + rank, [](float v) { return v == 1.0f; }) &&
|
||||
bool isSame = std::all_of(scales_vals.data_, scales_vals.data_ + rank, [](float v) { return v == 1.0f; }) &&
|
||||
(coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE);
|
||||
if (isSame) {
|
||||
cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice);
|
||||
|
|
@ -567,41 +561,41 @@ void ResizeImpl(
|
|||
}
|
||||
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
fast_divmod div_output_image = (rank > 2) ? output_div_pitches.CpuPtr()[rank - 3] : fast_divmod(gsl::narrow_cast<int>(N));
|
||||
int64_t output_height = output_shape.CpuPtr()[rank - 2];
|
||||
int64_t output_width = output_shape.CpuPtr()[rank - 1];
|
||||
fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(gsl::narrow_cast<int>(N));
|
||||
int64_t output_height = output_shape[rank - 2];
|
||||
int64_t output_width = output_shape[rank - 1];
|
||||
int blocksPerDimsMappingGrid = (int)(ceil((output_height + output_width) / 32.0));
|
||||
switch (upsample_mode) {
|
||||
case UpsampleMode::LINEAR:
|
||||
_ResizeBilinearCoordinateMapping<T><<<blocksPerDimsMappingGrid, 32, 0>>>(
|
||||
input_shape.CpuPtr()[rank - 2], input_shape.CpuPtr()[rank - 1],
|
||||
input_shape[rank - 2], input_shape[rank - 1],
|
||||
output_height, output_width,
|
||||
scales_vals.CpuPtr()[rank - 2], scales_vals.CpuPtr()[rank - 1],
|
||||
roi_vals.CpuPtr()[rank - 2], roi_vals.CpuPtr()[rank - 2 + rank],
|
||||
roi_vals.CpuPtr()[rank - 1], roi_vals.CpuPtr()[rank - 1 + rank],
|
||||
scales_vals[rank - 2], scales_vals[rank - 1],
|
||||
roi_vals[rank - 2], roi_vals[rank - 2 + rank],
|
||||
roi_vals[rank - 1], roi_vals[rank - 1 + rank],
|
||||
output_height + output_width, extrapolation_enabled, transform_coordinate,
|
||||
reinterpret_cast<BilinearMappingInfo*>(dims_mapping));
|
||||
_ResizeBilinearKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
input_shape.CpuPtr()[rank - 2], input_shape.CpuPtr()[rank - 1],
|
||||
input_shape[rank - 2], input_shape[rank - 1],
|
||||
output_height, output_width,
|
||||
output_div_pitches.CpuPtr()[rank - 2], div_output_image,
|
||||
output_div_pitches[rank - 2], div_output_image,
|
||||
input_data, output_data, N, extrapolation_value,
|
||||
reinterpret_cast<BilinearMappingInfo*>(dims_mapping));
|
||||
return;
|
||||
case UpsampleMode::CUBIC:
|
||||
_ResizeCubicCoordinateMapping<T><<<blocksPerDimsMappingGrid, 32, 0>>>(
|
||||
input_shape.CpuPtr()[rank - 2], input_shape.CpuPtr()[rank - 1],
|
||||
input_shape[rank - 2], input_shape[rank - 1],
|
||||
output_height, output_width,
|
||||
scales_vals.CpuPtr()[rank - 2], scales_vals.CpuPtr()[rank - 1],
|
||||
roi_vals.CpuPtr()[rank - 2], roi_vals.CpuPtr()[rank - 2 + rank],
|
||||
roi_vals.CpuPtr()[rank - 1], roi_vals.CpuPtr()[rank - 1 + rank],
|
||||
scales_vals[rank - 2], scales_vals[rank - 1],
|
||||
roi_vals[rank - 2], roi_vals[rank - 2 + rank],
|
||||
roi_vals[rank - 1], roi_vals[rank - 1 + rank],
|
||||
output_height + output_width, extrapolation_enabled,
|
||||
cubic_coeff_a, exclude_outside, transform_coordinate,
|
||||
reinterpret_cast<CubicMappingInfo*>(dims_mapping));
|
||||
_ResizeBiCubicKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
input_shape.CpuPtr()[rank - 2], input_shape.CpuPtr()[rank - 1],
|
||||
input_shape[rank - 2], input_shape[rank - 1],
|
||||
output_height, output_width,
|
||||
output_div_pitches.CpuPtr()[rank - 2], div_output_image,
|
||||
output_div_pitches[rank - 2], div_output_image,
|
||||
input_data, output_data, N, extrapolation_value,
|
||||
reinterpret_cast<CubicMappingInfo*>(dims_mapping));
|
||||
return;
|
||||
|
|
@ -612,12 +606,12 @@ void ResizeImpl(
|
|||
template void ResizeImpl<T>( \
|
||||
const UpsampleMode upsample_mode, \
|
||||
const int rank, \
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_shape, \
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& output_shape, \
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_strides, \
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& output_div_pitches, \
|
||||
CudaKernel::CudaAsyncBuffer<float>& scales_vals, \
|
||||
CudaKernel::CudaAsyncBuffer<float>& roi_vals, \
|
||||
TArray<int64_t>& input_shape, \
|
||||
TArray<int64_t>& output_shape, \
|
||||
TArray<int64_t>& input_strides, \
|
||||
TArray<fast_divmod>& output_div_pitches, \
|
||||
TArray<float>& scales_vals, \
|
||||
TArray<float>& roi_vals, \
|
||||
const T* input_data, \
|
||||
T* output_data, \
|
||||
const size_t N, \
|
||||
|
|
|
|||
|
|
@ -18,12 +18,12 @@ template <typename T>
|
|||
void ResizeImpl(
|
||||
const onnxruntime::UpsampleMode upsample_mode,
|
||||
const int rank,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_shape,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& output_shape,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& input_strides,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& output_div_pitches,
|
||||
CudaKernel::CudaAsyncBuffer<float>& scales_vals,
|
||||
CudaKernel::CudaAsyncBuffer<float>& roi,
|
||||
TArray<int64_t>& input_shape,
|
||||
TArray<int64_t>& output_shape,
|
||||
TArray<int64_t>& input_strides,
|
||||
TArray<fast_divmod>& output_div_pitches,
|
||||
TArray<float>& scales_vals,
|
||||
TArray<float>& roi,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N,
|
||||
|
|
|
|||
|
|
@ -120,13 +120,16 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
|
|||
int rank = (int)input_dims.size();
|
||||
auto* output_tensor = context->Output(0, input_data_shape);
|
||||
|
||||
CudaAsyncBuffer<int64_t> buffer_input_dims(this, input_dims);
|
||||
TArray<int64_t> buffer_input_dims(input_dims);
|
||||
TensorPitches input_strides(input_dims);
|
||||
CudaAsyncBuffer<int64_t> buffer_input_strides(this, input_strides);
|
||||
TArray<int64_t> buffer_input_strides(input_strides);
|
||||
|
||||
CudaAsyncBuffer<int64_t> buffer_indices_dims(this, indices_dims);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_indices_strides(this, rank);
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_indices_strides.CpuSpan(), indices_dims));
|
||||
TArray<int64_t> buffer_indices_dims(indices_dims);
|
||||
TArray<fast_divmod> fdm_indices_strides(rank);
|
||||
TensorPitches indices_strides(indices_dims);
|
||||
for (auto i = 0; i < rank; i++) {
|
||||
fdm_indices_strides[i] = fast_divmod(static_cast<int>(indices_strides[i]));
|
||||
}
|
||||
|
||||
MLDataType Tin_type = indices_tensor->DataType();
|
||||
MLDataType T_type = data_tensor->DataType();
|
||||
|
|
|
|||
|
|
@ -38,12 +38,12 @@ template <typename T, typename Tin>
|
|||
__global__ void _ScatterElementsKernel(
|
||||
const int rank,
|
||||
const T* input_data,
|
||||
const int64_t* input_dims,
|
||||
const int64_t* input_strides,
|
||||
const TArray<int64_t> input_dims,
|
||||
const TArray<int64_t> input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
const int64_t* indices_dims,
|
||||
const fast_divmod* indices_strides,
|
||||
const TArray<int64_t> indices_dims,
|
||||
const TArray<fast_divmod> indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data) {
|
||||
|
|
@ -166,12 +166,12 @@ Status ScatterElementsImpl(
|
|||
const int rank,
|
||||
const T* input_data,
|
||||
const int64_t input_size,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_input_dims,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_input_strides,
|
||||
TArray<int64_t>& buffer_input_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_indices_dims,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& fdm_indices_strides,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& fdm_indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data) {
|
||||
|
|
@ -183,20 +183,16 @@ Status ScatterElementsImpl(
|
|||
std::vector<int64_t> eff_input_dims;
|
||||
std::vector<int64_t> eff_indices_dims;
|
||||
int new_axis = CompactInputIndicesDims(
|
||||
rank, axis, buffer_input_dims.CpuPtr(), buffer_indices_dims.CpuPtr(), eff_input_dims, eff_indices_dims);
|
||||
rank, axis, buffer_input_dims.data_, buffer_indices_dims.data_, eff_input_dims, eff_indices_dims);
|
||||
if (eff_input_dims.size() == 2) {
|
||||
return ScatterElementsImpl2D(
|
||||
input_data, eff_input_dims, indices_data, indices_size, eff_indices_dims, updates, new_axis, output_data);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(buffer_input_dims.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(buffer_input_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(buffer_indices_dims.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_indices_strides.CopyToGpu());
|
||||
int blocksPerGrid = gsl::narrow_cast<int>(CeilDiv(indices_size, GridDim::maxThreadsPerBlock));
|
||||
_ScatterElementsKernel<T, Tin><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, input_data, buffer_input_dims.GpuPtr(), buffer_input_strides.GpuPtr(),
|
||||
indices_data, indices_size, buffer_indices_dims.GpuPtr(), fdm_indices_strides.GpuPtr(),
|
||||
rank, input_data, buffer_input_dims, buffer_input_strides,
|
||||
indices_data, indices_size, buffer_indices_dims, fdm_indices_strides,
|
||||
updates, axis, output_data);
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
@ -207,12 +203,12 @@ Status ScatterElementsImpl(
|
|||
const int rank, \
|
||||
const T* input_data, \
|
||||
const int64_t input_size, \
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_input_dims, \
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_input_strides, \
|
||||
TArray<int64_t>& buffer_input_dims, \
|
||||
TArray<int64_t>& buffer_input_strides, \
|
||||
const TIndex* indices_data, \
|
||||
const int64_t indices_size, \
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_indices_dims, \
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& indices_strides, \
|
||||
TArray<int64_t>& buffer_indices_dims, \
|
||||
TArray<fast_divmod>& indices_strides, \
|
||||
const T* updates, \
|
||||
const int axis, \
|
||||
T* output_data)
|
||||
|
|
|
|||
|
|
@ -14,12 +14,12 @@ Status ScatterElementsImpl(
|
|||
const int rank,
|
||||
const T* input_data,
|
||||
const int64_t input_size,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_input_dims,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_input_strides,
|
||||
TArray<int64_t>& buffer_input_dims,
|
||||
TArray<int64_t>& buffer_input_strides,
|
||||
const Tin* indices_data,
|
||||
const int64_t indices_size,
|
||||
CudaKernel::CudaAsyncBuffer<int64_t>& buffer_indices_dims,
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod>& indices_strides,
|
||||
TArray<int64_t>& buffer_indices_dims,
|
||||
TArray<fast_divmod>& indices_strides,
|
||||
const T* updates,
|
||||
const int axis,
|
||||
T* output_data);
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ template <typename T>
|
|||
Status Tile<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
auto& input_tensor = *ctx->Input<Tensor>(0);
|
||||
auto& repeats_tensor = *ctx->Input<Tensor>(1);
|
||||
size_t rank = input_tensor.Shape().NumDimensions();
|
||||
int32_t rank = static_cast<int32_t>(input_tensor.Shape().NumDimensions());
|
||||
|
||||
if (repeats_tensor.Shape().NumDimensions() != 1)
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional");
|
||||
|
|
@ -36,35 +36,35 @@ Status Tile<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
auto* repeats = repeats_tensor.template Data<int64_t>();
|
||||
const auto& input_shape = input_tensor.Shape().GetDims();
|
||||
std::vector<int64_t> output_dims(input_shape);
|
||||
for (size_t axis = 0; axis < rank; axis++)
|
||||
for (auto axis = 0; axis < rank; axis++)
|
||||
output_dims[axis] *= repeats[axis];
|
||||
TensorShape outputShape(output_dims);
|
||||
auto& output_tensor = *ctx->Output(0, outputShape);
|
||||
|
||||
T* output_data = output_tensor.template MutableData<T>();
|
||||
const T* input_data = input_tensor.template Data<T>();
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, rank);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_input_shape(this, rank);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, rank);
|
||||
|
||||
ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_shape));
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));
|
||||
TensorPitches input_pitches(input_shape);
|
||||
TArray<int64_t> input_strides(input_pitches);
|
||||
|
||||
auto fdm_input_shape_span = fdm_input_shape.CpuSpan();
|
||||
for (size_t i = 0; i < input_shape.size(); ++i)
|
||||
fdm_input_shape_span[i] = fast_divmod(gsl::narrow_cast<int>(input_shape[i]));
|
||||
TArray<fast_divmod> fdm_input_shape(rank);
|
||||
for (int32_t i = 0; i < input_shape.size(); ++i) {
|
||||
fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_shape[i]));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(fdm_input_shape.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());
|
||||
TArray<fast_divmod> fdm_output_strides(rank);
|
||||
TensorPitches output_pitches(output_dims);
|
||||
for (auto i = 0; i < rank; i++) {
|
||||
fdm_output_strides[i] = fast_divmod(static_cast<int>(output_pitches[i]));
|
||||
}
|
||||
|
||||
if (output_tensor.Shape().Size() > 0) {
|
||||
TileImpl(
|
||||
rank,
|
||||
fdm_input_shape.GpuPtr(),
|
||||
input_strides.GpuPtr(),
|
||||
fdm_input_shape,
|
||||
input_strides,
|
||||
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(input_data),
|
||||
fdm_output_strides.GpuPtr(),
|
||||
fdm_output_strides,
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(output_data),
|
||||
output_tensor.Shape().Size());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ namespace cuda {
|
|||
template <typename T>
|
||||
__global__ void _TileKernel(
|
||||
const size_t shape_rank,
|
||||
const fast_divmod* fdm_input_shape,
|
||||
const int64_t* input_strides,
|
||||
const TArray<fast_divmod> fdm_input_shape,
|
||||
const TArray<int64_t> input_strides,
|
||||
const T* input_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod> fdm_output_strides,
|
||||
T* output_data,
|
||||
const CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
|
|
@ -33,10 +33,10 @@ __global__ void _TileKernel(
|
|||
template <typename T>
|
||||
void TileImpl(
|
||||
const size_t shape_rank,
|
||||
const fast_divmod* fdm_input_shape,
|
||||
const int64_t* input_stride,
|
||||
const TArray<fast_divmod>& fdm_input_shape,
|
||||
const TArray<int64_t>& input_stride,
|
||||
const T* input_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
|
|
@ -46,7 +46,7 @@ void TileImpl(
|
|||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void TileImpl<T>(const size_t shape_rank, const fast_divmod* fdm_input_shape, const int64_t* input_stride, const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, const size_t N);
|
||||
template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N);
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double)
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ namespace cuda {
|
|||
template <typename T>
|
||||
void TileImpl(
|
||||
const size_t shape_rank,
|
||||
const fast_divmod* input_shape,
|
||||
const int64_t* input_strides,
|
||||
const TArray<fast_divmod>& input_shape,
|
||||
const TArray<int64_t>& input_strides,
|
||||
const T* input_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
const size_t N);
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
|
|||
const std::vector<int64_t>& output_dims) const {
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const std::vector<int64_t>& X_dims = X->Shape().GetDims();
|
||||
auto rank = X_dims.size();
|
||||
int32_t rank = static_cast<int32_t>(X_dims.size());
|
||||
|
||||
ORT_ENFORCE(output_dims.size() == rank, "Rank of input and output tensor should be same.");
|
||||
if (rank == 0)
|
||||
|
|
@ -55,24 +55,21 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
|
|||
|
||||
// kernel
|
||||
TensorPitches input_pitches(X_dims);
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, rank);
|
||||
gsl::span<int64_t> input_stride_span = input_strides.CpuSpan();
|
||||
TArray<int64_t> input_strides(input_pitches);
|
||||
|
||||
TensorPitches output_pitches(output_dims);
|
||||
CudaAsyncBuffer<fast_divmod> output_div_pitches(this, rank);
|
||||
gsl::span<fast_divmod> div_strides_span = output_div_pitches.CpuSpan();
|
||||
TArray<fast_divmod> output_div_pitches(rank);
|
||||
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
input_stride_span[i] = input_pitches[i];
|
||||
div_strides_span[i] = fast_divmod(gsl::narrow_cast<int>(output_pitches[i]));
|
||||
for (int32_t i = 0; i < rank; ++i) {
|
||||
output_div_pitches[i] = fast_divmod(gsl::narrow_cast<int>(output_pitches[i]));
|
||||
}
|
||||
size_t output_count = Y->Shape().Size();
|
||||
|
||||
if (is_resize_) {
|
||||
CudaAsyncBuffer<int64_t> input_shape(this, X_dims);
|
||||
CudaAsyncBuffer<int64_t> output_shape(this, output_dims);
|
||||
CudaAsyncBuffer<float> roi_vals(this, roi);
|
||||
CudaAsyncBuffer<float> scales_vals(this, scales);
|
||||
TArray<int64_t> input_shape(X_dims);
|
||||
TArray<int64_t> output_shape(output_dims);
|
||||
TArray<float> roi_vals(roi);
|
||||
TArray<float> scales_vals(scales);
|
||||
|
||||
size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims);
|
||||
auto dims_mapping_buffer = GetScratchBuffer<unsigned char>(temp_buffer_size);
|
||||
|
|
@ -86,23 +83,18 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
|
|||
coordinate_transform_mode_, nearest_mode_,
|
||||
dims_mapping);
|
||||
} else {
|
||||
input_strides.CopyToGpu();
|
||||
output_div_pitches.CopyToGpu();
|
||||
TArray<fast_divmod> scales_div(rank);
|
||||
|
||||
CudaAsyncBuffer<fast_divmod> scales_div(this, rank);
|
||||
gsl::span<fast_divmod> scales_div_span = scales_div.CpuSpan();
|
||||
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
scales_div_span[i] = fast_divmod(gsl::narrow_cast<int>(ceil(scales[i])));
|
||||
for (int32_t i = 0; i < rank; ++i) {
|
||||
scales_div[i] = fast_divmod(gsl::narrow_cast<int>(ceil(scales[i])));
|
||||
}
|
||||
scales_div.CopyToGpu();
|
||||
|
||||
UpampleImpl(mode_,
|
||||
rank,
|
||||
(UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0,
|
||||
input_strides.GpuPtr(),
|
||||
output_div_pitches.GpuPtr(),
|
||||
scales_div.GpuPtr(),
|
||||
input_strides,
|
||||
output_div_pitches,
|
||||
scales_div,
|
||||
reinterpret_cast<const CudaT*>(X->template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(Y->template MutableData<T>()),
|
||||
output_count);
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ namespace cuda {
|
|||
|
||||
template <typename T>
|
||||
__global__ void _UpampleNearestKernel(const size_t rank,
|
||||
const int64_t* input_pitches,
|
||||
const fast_divmod* output_div_pitches,
|
||||
const fast_divmod* scales_div,
|
||||
const TArray<int64_t> input_pitches,
|
||||
const TArray<fast_divmod> output_div_pitches,
|
||||
const TArray<fast_divmod> scales_div,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
|
|
@ -38,9 +38,9 @@ __global__ void _UpampleNearestKernel(const size_t rank,
|
|||
// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
|
||||
template <typename T>
|
||||
__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2,
|
||||
const int64_t* input_pitches,
|
||||
const fast_divmod* output_div_pitches,
|
||||
const fast_divmod* scales_div,
|
||||
const TArray<int64_t> input_pitches,
|
||||
const TArray<fast_divmod> output_div_pitches,
|
||||
const TArray<fast_divmod> scales_div,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
|
|
@ -98,9 +98,9 @@ __global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2,
|
|||
// The following method supports a 2-D input in 'Linear mode'
|
||||
template <typename T>
|
||||
__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0,
|
||||
const int64_t* input_pitches,
|
||||
const fast_divmod* output_div_pitches,
|
||||
const fast_divmod* scales_div,
|
||||
const TArray<int64_t> input_pitches,
|
||||
const TArray<fast_divmod> output_div_pitches,
|
||||
const TArray<fast_divmod> scales_div,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
|
|
@ -152,9 +152,9 @@ template <typename T>
|
|||
void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode,
|
||||
const size_t rank,
|
||||
const int64_t input_dim2,
|
||||
const int64_t* input_pitches,
|
||||
const fast_divmod* output_div_pitches,
|
||||
const fast_divmod* scales_div,
|
||||
const TArray<int64_t>& input_pitches,
|
||||
const TArray<fast_divmod>& output_div_pitches,
|
||||
const TArray<fast_divmod>& scales_div,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N) {
|
||||
|
|
@ -178,9 +178,9 @@ void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode,
|
|||
template void UpampleImpl<T>(const onnxruntime::UpsampleMode upsample_mode, \
|
||||
const size_t rank, \
|
||||
const int64_t input_dim2, \
|
||||
const int64_t* input_pitches, \
|
||||
const fast_divmod* output_div_pitches, \
|
||||
const fast_divmod* scales_div, \
|
||||
const TArray<int64_t>& input_pitches, \
|
||||
const TArray<fast_divmod>& output_div_pitches, \
|
||||
const TArray<fast_divmod>& scales_div, \
|
||||
const T* input_data, \
|
||||
T* output_data, \
|
||||
const size_t N);
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ template <typename T>
|
|||
void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode,
|
||||
const size_t rank,
|
||||
const int64_t input_dim2,
|
||||
const int64_t* input_pitches,
|
||||
const fast_divmod* output_div_pitches,
|
||||
const fast_divmod* scales_div,
|
||||
const TArray<int64_t>& input_pitches,
|
||||
const TArray<fast_divmod>& output_div_pitches,
|
||||
const TArray<fast_divmod>& scales_div,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const size_t N);
|
||||
|
|
|
|||
|
|
@ -69,37 +69,22 @@ struct TernaryElementwisePreparation {
|
|||
const Tensor* b_tensor = nullptr;
|
||||
const Tensor* c_tensor = nullptr;
|
||||
size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums
|
||||
CudaKernel::CudaAsyncBuffer<int64_t> a_padded_strides; // for a shape == output shape, this is nullptr
|
||||
CudaKernel::CudaAsyncBuffer<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
|
||||
CudaKernel::CudaAsyncBuffer<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
|
||||
CudaKernel::CudaAsyncBuffer<fast_divmod> fdm_output_strides;
|
||||
TArray<int64_t> a_padded_strides; // for a shape == output shape, this is nullptr
|
||||
TArray<int64_t> b_padded_strides; // for b shape == output shape, this is nullptr
|
||||
TArray<int64_t> c_padded_strides; // for c shape == output shape, this is nullptr
|
||||
TArray<fast_divmod> fdm_output_strides;
|
||||
|
||||
TernaryElementwisePreparation(const CudaKernel* op_kernel, const Tensor* a,
|
||||
const Tensor* b, const Tensor* c)
|
||||
: a_padded_strides(op_kernel),
|
||||
b_padded_strides(op_kernel),
|
||||
c_padded_strides(op_kernel),
|
||||
fdm_output_strides(op_kernel),
|
||||
a_tensor(a),
|
||||
b_tensor(b),
|
||||
c_tensor(c) {}
|
||||
|
||||
Status CopyToGpu() {
|
||||
ORT_RETURN_IF_ERROR(a_padded_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(b_padded_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(c_padded_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());
|
||||
return Status::OK();
|
||||
}
|
||||
TernaryElementwisePreparation(const Tensor* a, const Tensor* b, const Tensor* c)
|
||||
: a_tensor(a), b_tensor(b), c_tensor(c) {}
|
||||
|
||||
Status TernaryElementwiseBroadcastPrepareHelper(const TensorShape& a_shape,
|
||||
const TensorShape& b_shape,
|
||||
const TensorShape& c_shape,
|
||||
const TensorShape& output_shape) {
|
||||
size_t a_rank = a_shape.NumDimensions();
|
||||
size_t b_rank = b_shape.NumDimensions();
|
||||
size_t c_rank = c_shape.NumDimensions();
|
||||
size_t out_rank = std::max(std::max(a_rank, b_rank), c_rank);
|
||||
int32_t a_rank = static_cast<int32_t>(a_shape.NumDimensions());
|
||||
int32_t b_rank = static_cast<int32_t>(b_shape.NumDimensions());
|
||||
int32_t c_rank = static_cast<int32_t>(c_shape.NumDimensions());
|
||||
int32_t out_rank = std::max(std::max(a_rank, b_rank), c_rank);
|
||||
|
||||
// early return when shapes match
|
||||
if (a_shape == b_shape && b_shape == c_shape) {
|
||||
|
|
@ -110,30 +95,47 @@ struct TernaryElementwisePreparation {
|
|||
output_rank_or_simple_broadcast = out_rank;
|
||||
|
||||
if (a_shape != output_shape) {
|
||||
// compute strides with 1 more dim than out_rank, and use strides[0] == strides[1]
|
||||
// to decide if dim0 needs broadcast
|
||||
a_padded_strides.AllocCpuPtr(out_rank + 1);
|
||||
ORT_RETURN_IF_NOT(TensorPitches::Calculate(a_padded_strides.CpuSpan(), a_shape.GetDims()));
|
||||
if (a_shape[0] > 1 && a_rank == out_rank)
|
||||
a_padded_strides.CpuPtr()[0] = 0;
|
||||
TensorPitches a_pitches(a_shape.GetDims());
|
||||
a_padded_strides.size_ = out_rank;
|
||||
auto offset = out_rank - a_rank;
|
||||
for (auto i = offset; i < out_rank; ++i) {
|
||||
// the stride for broadcast dimension is kept as 0
|
||||
if (a_shape.GetDims()[i - offset] != 1) {
|
||||
a_padded_strides[i] = a_pitches[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (b_shape != output_shape) {
|
||||
b_padded_strides.AllocCpuPtr(out_rank + 1);
|
||||
ORT_RETURN_IF_NOT(TensorPitches::Calculate(b_padded_strides.CpuSpan(), b_shape.GetDims()));
|
||||
if (b_shape[0] > 1 && b_rank == out_rank)
|
||||
b_padded_strides.CpuPtr()[0] = 0;
|
||||
TensorPitches b_pitches(b_shape.GetDims());
|
||||
b_padded_strides.size_ = out_rank;
|
||||
auto offset = out_rank - b_rank;
|
||||
for (auto i = offset; i < out_rank; ++i) {
|
||||
// the stride for broadcast dimension is kept as 0
|
||||
if (b_shape.GetDims()[i - offset] != 1) {
|
||||
b_padded_strides[i] = b_pitches[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (c_shape != output_shape) {
|
||||
c_padded_strides.AllocCpuPtr(out_rank + 1);
|
||||
ORT_RETURN_IF_NOT(TensorPitches::Calculate(c_padded_strides.CpuSpan(), c_shape.GetDims()));
|
||||
if (c_shape[0] > 1 && c_rank == out_rank)
|
||||
c_padded_strides.CpuPtr()[0] = 0;
|
||||
TensorPitches c_pitches(c_shape.GetDims());
|
||||
c_padded_strides.size_ = out_rank;
|
||||
auto offset = out_rank - c_rank;
|
||||
for (auto i = offset; i < out_rank; ++i) {
|
||||
// the stride for broadcast dimension is kept as 0
|
||||
if (c_shape.GetDims()[i - offset] != 1) {
|
||||
c_padded_strides[i] = c_pitches[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorPitches output_pitches(output_shape.GetDims());
|
||||
fdm_output_strides.size_ = out_rank;
|
||||
for (auto i = 0; i < out_rank; ++i) {
|
||||
fdm_output_strides[i] = fast_divmod(static_cast<int32_t>(output_pitches[i]));
|
||||
}
|
||||
|
||||
fdm_output_strides.AllocCpuPtr(out_rank);
|
||||
ORT_RETURN_IF_NOT(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_shape.GetDims()));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
|
@ -157,19 +159,18 @@ Status Where<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
if (output_shape.Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
TernaryElementwisePreparation prepare(this, condition, X, Y);
|
||||
TernaryElementwisePreparation prepare(condition, X, Y);
|
||||
ORT_RETURN_IF_ERROR(prepare.TernaryElementwiseBroadcastPrepareHelper(condition_shape, X_shape, Y_shape, output_shape));
|
||||
ORT_RETURN_IF_ERROR(prepare.CopyToGpu());
|
||||
|
||||
WhereImpl<CudaT>(
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
prepare.a_padded_strides.GpuPtr(),
|
||||
prepare.a_padded_strides,
|
||||
reinterpret_cast<const bool*>(prepare.a_tensor->template Data<bool>()),
|
||||
prepare.b_padded_strides.GpuPtr(),
|
||||
prepare.b_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.b_tensor->template Data<T>()),
|
||||
prepare.c_padded_strides.GpuPtr(),
|
||||
prepare.c_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.c_tensor->template Data<T>()),
|
||||
prepare.fdm_output_strides.GpuPtr(),
|
||||
prepare.fdm_output_strides,
|
||||
reinterpret_cast<CudaT*>(output_tensor->template MutableData<T>()),
|
||||
output_tensor->Shape().Size());
|
||||
|
||||
|
|
|
|||
|
|
@ -10,46 +10,43 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
// broadcast by computing output coordinate from offset, using fast_divmod
|
||||
template <typename T>
|
||||
template <typename T, bool cond_need_compute, bool x_need_compute, bool y_need_compute>
|
||||
__global__ void _TenaryElementWise(
|
||||
size_t output_rank,
|
||||
const int64_t* cond_padded_strides,
|
||||
const TArray<int64_t> cond_padded_strides,
|
||||
const bool* cond_data,
|
||||
const int64_t* x_padded_strides,
|
||||
const TArray<int64_t> x_padded_strides,
|
||||
const T* x_data,
|
||||
const int64_t* y_padded_strides,
|
||||
const TArray<int64_t> y_padded_strides,
|
||||
const T* y_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod> fdm_output_strides,
|
||||
T* output_data,
|
||||
CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
bool cond_need_compute = cond_padded_strides != NULL;
|
||||
bool x_need_compute = x_padded_strides != NULL;
|
||||
bool y_need_compute = y_padded_strides != NULL;
|
||||
CUDA_LONG cond_index = (cond_need_compute ? 0 : id);
|
||||
CUDA_LONG x_index = (x_need_compute ? 0 : id);
|
||||
CUDA_LONG y_index = (y_need_compute ? 0 : id);
|
||||
|
||||
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
|
||||
CUDA_LONG offset = id;
|
||||
for (int dim = 0; dim < output_rank; dim++) {
|
||||
for (auto dim = 0; dim < fdm_output_strides.GetCapacity(); dim++) {
|
||||
if (dim >= output_rank) {
|
||||
break;
|
||||
}
|
||||
|
||||
int q, r;
|
||||
fdm_output_strides[dim].divmod(offset, q, r);
|
||||
// compute index increase based on stride and broadcast
|
||||
// note that stride[i-1] == stride[i] means dim[i] is 1 (broadcasting)
|
||||
|
||||
if (cond_need_compute) {
|
||||
if (cond_padded_strides[dim] != cond_padded_strides[dim + 1])
|
||||
cond_index += static_cast<int>(cond_padded_strides[dim + 1]) * q;
|
||||
cond_index += static_cast<int>(cond_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
if (x_need_compute) {
|
||||
if (x_padded_strides[dim] != x_padded_strides[dim + 1])
|
||||
x_index += static_cast<int>(x_padded_strides[dim + 1]) * q;
|
||||
x_index += static_cast<int>(x_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
if (y_need_compute) {
|
||||
if (y_padded_strides[dim] != y_padded_strides[dim + 1])
|
||||
y_index += static_cast<int>(y_padded_strides[dim + 1]) * q;
|
||||
y_index += static_cast<int>(y_padded_strides[dim]) * q;
|
||||
}
|
||||
|
||||
offset = r;
|
||||
|
|
@ -73,13 +70,13 @@ __global__ void _TenaryElementWiseSimple(
|
|||
template <typename T>
|
||||
void WhereImpl(
|
||||
size_t output_rank_or_simple_broadcast,
|
||||
const int64_t* cond_padded_strides,
|
||||
const TArray<int64_t>& cond_padded_strides,
|
||||
const bool* cond_data,
|
||||
const int64_t* x_padded_strides,
|
||||
const TArray<int64_t>& x_padded_strides,
|
||||
const T* x_data,
|
||||
const int64_t* y_padded_strides,
|
||||
const TArray<int64_t>& y_padded_strides,
|
||||
const T* y_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
size_t count) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(count) / GridDim::maxThreadsPerBlock));
|
||||
|
|
@ -93,7 +90,8 @@ void WhereImpl(
|
|||
output_data,
|
||||
N);
|
||||
} else {
|
||||
_TenaryElementWise<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
if (cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
|
|
@ -104,18 +102,103 @@ void WhereImpl(
|
|||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, false, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (!cond_padded_strides.size_ && x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, false, true, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (cond_padded_strides.size_ && !x_padded_strides.size_ && !y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, true, false, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (!cond_padded_strides.size_ && x_padded_strides.size_ && !y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, false, true, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else if (!cond_padded_strides.size_ && !x_padded_strides.size_ && y_padded_strides.size_) {
|
||||
_TenaryElementWise<T, false, false, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
} else {
|
||||
_TenaryElementWise<T, false, false, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
output_rank_or_simple_broadcast,
|
||||
cond_padded_strides,
|
||||
cond_data,
|
||||
x_padded_strides,
|
||||
x_data,
|
||||
y_padded_strides,
|
||||
y_data,
|
||||
fdm_output_strides,
|
||||
output_data,
|
||||
N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void WhereImpl<T>(size_t output_rank_or_simple_broadcast, \
|
||||
const int64_t* cond_padded_strides, \
|
||||
const TArray<int64_t>& cond_padded_strides, \
|
||||
const bool* cond_data, \
|
||||
const int64_t* x_padded_strides, \
|
||||
const TArray<int64_t>& x_padded_strides, \
|
||||
const T* x_data, \
|
||||
const int64_t* y_padded_strides, \
|
||||
const TArray<int64_t>& y_padded_strides, \
|
||||
const T* y_data, \
|
||||
const fast_divmod* fdm_output_strides, \
|
||||
const TArray<fast_divmod>& fdm_output_strides, \
|
||||
T* output_data, \
|
||||
size_t count);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ namespace cuda {
|
|||
template <typename T>
|
||||
void WhereImpl(
|
||||
size_t output_rank_or_simple_broadcast,
|
||||
const int64_t* cond_padded_strides,
|
||||
const TArray<int64_t>& cond_padded_strides,
|
||||
const bool* cond_data,
|
||||
const int64_t* x_padded_strides,
|
||||
const TArray<int64_t>& x_padded_strides,
|
||||
const T* x_data,
|
||||
const int64_t* y_padded_strides,
|
||||
const TArray<int64_t>& y_padded_strides,
|
||||
const T* y_data,
|
||||
const fast_divmod* fdm_output_strides,
|
||||
const TArray<fast_divmod>& fdm_output_strides,
|
||||
T* output_data,
|
||||
size_t count);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue