mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Add CUDA Expand operator (#1292)
* Add CUDA expand operator * Reset counter variables when striding * Reset counter variables when striding * use fast_divmod and other PR comments * Fix merge variable rename * Fix indentation per PR comment * Remove maxpool_argmax * Reduce number of type templates for Expand operator * removed all types * Commit updated cuda_execution_provider.cc
This commit is contained in:
parent
a79ab5ec5b
commit
59de37af1f
8 changed files with 237 additions and 2 deletions
|
|
@ -344,6 +344,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, Expand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater);
|
||||
|
|
@ -668,6 +669,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater)>,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
|
||||
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
|
||||
size_t lhs_rank = lhs_shape.NumDimensions();
|
||||
size_t rhs_rank = rhs_shape.NumDimensions();
|
||||
size_t out_rank = std::max(lhs_rank, rhs_rank);
|
||||
|
|
|
|||
76
onnxruntime/core/providers/cuda/tensor/expand.cc
Normal file
76
onnxruntime/core/providers/cuda/tensor/expand.cc
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "expand.h"
|
||||
#include "expand_impl.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
Status Expand::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const auto& input0 = *ctx->Input<Tensor>(0);
|
||||
const auto& input1 = *ctx->Input<Tensor>(1);
|
||||
int device_id = GetDeviceId();
|
||||
|
||||
// new shape to be expanded to
|
||||
const auto* p_shape = input1.template Data<int64_t>();
|
||||
std::vector<int64_t> output_dims{p_shape, p_shape + input1.Shape().Size()};
|
||||
TensorShape output_shape(output_dims);
|
||||
|
||||
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input0.Shape(), output_dims, output_shape));
|
||||
auto rank = output_shape.NumDimensions();
|
||||
auto& output_tensor = *ctx->Output(0, output_shape);
|
||||
auto input_shape = input0.Shape().GetDims();
|
||||
|
||||
// pad input_dims with 1 to make ranks match
|
||||
for (int i = 0; i < rank - input_shape.size(); i++) {
|
||||
input_shape.insert(input_shape.begin(), 1);
|
||||
}
|
||||
|
||||
// create fast_divmod using dimension values
|
||||
CudaAsyncBuffer<fast_divmod> fdm_input_dims(this, device_id, rank);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_dims(this, device_id, rank);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_subdim_size(this, device_id, rank);
|
||||
{
|
||||
auto in_span = fdm_input_dims.CpuSpan();
|
||||
auto out_span = fdm_output_dims.CpuSpan();
|
||||
auto sdm_span = fdm_output_subdim_size.CpuSpan();
|
||||
auto subdim_size = output_shape.Size();
|
||||
for (auto i = 0; i < rank; i++) {
|
||||
in_span[i] = fast_divmod(static_cast<int>(input_shape[i]));
|
||||
out_span[i] = fast_divmod(static_cast<int>(output_shape[i]));
|
||||
subdim_size /= output_shape[i];
|
||||
sdm_span[i] = static_cast<int>(subdim_size);
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(fdm_input_dims.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_dims.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_subdim_size.CopyToGpu());
|
||||
|
||||
ExpandImpl(
|
||||
input0.DataType()->Size(),
|
||||
output_shape.NumDimensions(),
|
||||
output_shape.Size(),
|
||||
input0.Shape().Size(),
|
||||
input0.DataRaw(),
|
||||
output_tensor.MutableDataRaw(),
|
||||
fdm_input_dims.GpuPtr(),
|
||||
fdm_output_dims.GpuPtr(),
|
||||
fdm_output_subdim_size.GpuPtr());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Expand,
|
||||
kOnnxDomain,
|
||||
8,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(1),
|
||||
Expand);
|
||||
|
||||
} // namespace cuda
|
||||
}; // namespace onnxruntime
|
||||
25
onnxruntime/core/providers/cuda/tensor/expand.h
Normal file
25
onnxruntime/core/providers/cuda/tensor/expand.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
class Expand final : public CudaKernel {
|
||||
public:
|
||||
Expand(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
Status ComputeOutputShape(
|
||||
const std::string& node_name,
|
||||
const TensorShape& lhs_shape,
|
||||
const TensorShape& rhs_shape,
|
||||
TensorShape& out_shape);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
96
onnxruntime/core/providers/cuda/tensor/expand_impl.cu
Normal file
96
onnxruntime/core/providers/cuda/tensor/expand_impl.cu
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "expand_impl.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
__global__ void ExpandKernel(
|
||||
const size_t rank,
|
||||
const size_t N,
|
||||
const size_t N_input,
|
||||
const T* input_data,
|
||||
T* output_data,
|
||||
const fast_divmod* fdm_input_dims,
|
||||
const fast_divmod* fdm_output_dims,
|
||||
const fast_divmod* fdm_output_subdim_size) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
|
||||
// initialize
|
||||
auto output_index = id;
|
||||
auto input_index = 0;
|
||||
auto input_subdim_size = N_input;
|
||||
auto out_coord = output_index;
|
||||
// use striding when tensor is larger than grid
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
|
||||
// translate indices to coordinates. copy expanded dims from source
|
||||
while (output_index < N) {
|
||||
for (int64_t i = 0; i < rank; i++) {
|
||||
input_subdim_size = fdm_input_dims[i].div(input_subdim_size);
|
||||
auto new_out_coord = fdm_output_subdim_size[i].div(out_coord);
|
||||
auto in_coord = (new_out_coord > (fdm_input_dims[i].d_ - 1)) ? fdm_input_dims[i].d_ - 1 : new_out_coord;
|
||||
input_index += input_subdim_size * in_coord;
|
||||
out_coord -= new_out_coord * fdm_output_subdim_size[i].d_;
|
||||
}
|
||||
output_data[output_index] = input_data[input_index];
|
||||
output_index += stride;
|
||||
out_coord = output_index;
|
||||
input_subdim_size = N_input;
|
||||
input_index = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Status ExpandImpl(
|
||||
const size_t element_size,
|
||||
const size_t rank,
|
||||
const size_t N,
|
||||
const size_t N_input,
|
||||
const void* input_data,
|
||||
void* output_data,
|
||||
const fast_divmod* fdm_input_dims,
|
||||
const fast_divmod* fdm_output_dims,
|
||||
const fast_divmod* fdm_output_subdim_size) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
|
||||
switch (element_size) {
|
||||
case sizeof(uint8_t):
|
||||
ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, N, N_input,
|
||||
reinterpret_cast<const ToCudaType<uint8_t>::MappedType*>(input_data),
|
||||
reinterpret_cast<ToCudaType<uint8_t>::MappedType*>(output_data),
|
||||
fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, N, N_input,
|
||||
reinterpret_cast<const ToCudaType<uint16_t>::MappedType*>(input_data),
|
||||
reinterpret_cast<ToCudaType<uint16_t>::MappedType*>(output_data),
|
||||
fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
|
||||
break;
|
||||
case sizeof(uint32_t):
|
||||
ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, N, N_input,
|
||||
reinterpret_cast<const ToCudaType<uint32_t>::MappedType*>(input_data),
|
||||
reinterpret_cast<ToCudaType<uint32_t>::MappedType*>(output_data),
|
||||
fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
|
||||
break;
|
||||
case sizeof(uint64_t):
|
||||
ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
rank, N, N_input,
|
||||
reinterpret_cast<const ToCudaType<uint64_t>::MappedType*>(input_data),
|
||||
reinterpret_cast<ToCudaType<uint64_t>::MappedType*>(output_data),
|
||||
fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
|
||||
break;
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Expand operator");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
25
onnxruntime/core/providers/cuda/tensor/expand_impl.h
Normal file
25
onnxruntime/core/providers/cuda/tensor/expand_impl.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/common/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
Status ExpandImpl(
|
||||
const size_t element_size,
|
||||
const size_t shape_rank,
|
||||
const size_t N,
|
||||
const size_t N_input,
|
||||
const void* input_data,
|
||||
void* output_data,
|
||||
const fast_divmod* fdm_input_dims,
|
||||
const fast_divmod* fdm_output_dims,
|
||||
const fast_divmod* fdm_output_subdim_size);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
template <typename T >
|
||||
__global__ void _SliceKernel(const int32_t dimension_count,
|
||||
const int64_t* starts,
|
||||
const int64_t* steps,
|
||||
|
|
|
|||
|
|
@ -950,6 +950,17 @@ TEST(MathOpTest, Expand_8_1x3_int64) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x1x3x1_int64) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<int64_t>("data_0", {1, 3, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
test.AddInput<int64_t>("data_1", {4}, {3, 1, 3, 1});
|
||||
test.AddOutput<int64_t>("result", {3, 3, 3, 3},
|
||||
{1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
|
||||
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
|
||||
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MathOpTest, Expand_8_3x3_float16) {
|
||||
OpTester test("Expand", 8);
|
||||
test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});
|
||||
|
|
|
|||
Loading…
Reference in a new issue