mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
implement dynamic slice cuda (#286)
* implement dynamic slice cuda * add template parameter * add declaration * init base class * exclude case from cuda * use cuda mapped type * separate function implementation * add cpy logic * refactor * add type check * use InputMemoryType * merge functions
This commit is contained in:
parent
98a92547bf
commit
fa0ea9a273
6 changed files with 70 additions and 49 deletions
|
|
@ -113,34 +113,28 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, typename Tind, bool dynamic>
|
||||
void Slice<T, Tind, dynamic>::FillVectors(const OpKernelContext* context,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes) const {
|
||||
ORT_ENFORCE(context->Input<Tensor>(1) != nullptr, "Required starts input is missing");
|
||||
ORT_ENFORCE(context->Input<Tensor>(2) != nullptr, "Required ends input is missing");
|
||||
template <typename Tind>
|
||||
void SliceBase::FillVectorsFromInput(const OpKernelContext* context,
|
||||
std::vector<int64_t>& input_starts,
|
||||
std::vector<int64_t>& input_ends,
|
||||
std::vector<int64_t>& input_axes) const {
|
||||
auto stat_tensor = context->Input<Tensor>(1);
|
||||
auto ends_tensor = context->Input<Tensor>(2);
|
||||
auto axes_tensor = context->Input<Tensor>(3);
|
||||
|
||||
auto starts_tensor_ptr = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(starts_tensor_ptr->Shape().NumDimensions() == 1, "Starts input must be a 1-D array");
|
||||
input_starts = std::vector<int64_t> (starts_tensor_ptr->Data<Tind>(),
|
||||
starts_tensor_ptr->Data<Tind>() +
|
||||
starts_tensor_ptr->Shape().Size());
|
||||
|
||||
auto ends_tensor_ptr = context->Input<Tensor>(2);
|
||||
ORT_ENFORCE(ends_tensor_ptr->Shape().NumDimensions() == 1, "ends input must be a 1-D array");
|
||||
input_ends = std::vector<int64_t> (ends_tensor_ptr->Data<Tind>(),
|
||||
ends_tensor_ptr->Data<Tind>() +
|
||||
ends_tensor_ptr->Shape().Size());
|
||||
ORT_ENFORCE (nullptr != stat_tensor && stat_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array" );
|
||||
ORT_ENFORCE (nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "ends must be a 1-D array" );
|
||||
ORT_ENFORCE (stat_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
|
||||
ORT_ENFORCE (nullptr == axes_tensor || stat_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
|
||||
|
||||
ORT_ENFORCE(input_starts.size() == input_ends.size(), "Found mismatch between starts and ends input");
|
||||
|
||||
if (context->Input<Tensor>(3) != nullptr) {
|
||||
auto axes_tensor_ptr = context->Input<Tensor>(3);
|
||||
input_axes = std::vector<int64_t> (axes_tensor_ptr->Data<Tind>(),
|
||||
axes_tensor_ptr->Data<Tind>() +
|
||||
axes_tensor_ptr->Shape().Size());
|
||||
ORT_ENFORCE(input_axes.size() == input_starts.size(), "Axes input is invalid");
|
||||
auto size = stat_tensor->Shape().Size();
|
||||
input_starts.resize(size);
|
||||
std::copy(stat_tensor->Data<Tind>(), stat_tensor->Data<Tind>() + size, input_starts.begin());
|
||||
input_ends.resize(size);
|
||||
std::copy(ends_tensor->Data<Tind>(), ends_tensor->Data<Tind>() + size, input_ends.begin());
|
||||
if (nullptr != axes_tensor) {
|
||||
input_axes.resize(size);
|
||||
std::copy(axes_tensor->Data<Tind>(), axes_tensor->Data<Tind>() + size, input_axes.begin());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -158,7 +152,7 @@ Status Slice<T, Tind, dynamic>::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
if (dynamic) {
|
||||
std::vector<int64_t> input_starts, input_ends, input_axes;
|
||||
FillVectors(ctx, input_starts, input_ends, input_axes);
|
||||
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ class SliceBase {
|
|||
auto has_ends = info.GetAttrs("ends", attr_ends_).IsOK();
|
||||
auto has_axes = info.GetAttrs("axes", attr_axes_).IsOK();
|
||||
ORT_ENFORCE(has_starts && has_ends && attr_starts_.size() == attr_ends_.size(),
|
||||
"Missing or invalid starts and ends attribute");
|
||||
"Missing or invalid starts and ends attribute");
|
||||
ORT_ENFORCE(!has_axes || attr_axes_.size() == attr_starts_.size(),
|
||||
"Invalid axes attribute");
|
||||
"Invalid axes attribute");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -28,6 +28,11 @@ class SliceBase {
|
|||
const std::vector<int64_t>& input_dimensions,
|
||||
std::vector<int64_t>& starts,
|
||||
std::vector<int64_t>& output_dims) const;
|
||||
template<typename Tind>
|
||||
void FillVectorsFromInput(const OpKernelContext* context,
|
||||
std::vector<int64_t>& raw_starts,
|
||||
std::vector<int64_t>& raw_ends,
|
||||
std::vector<int64_t>& raw_axes) const;
|
||||
|
||||
std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
|
||||
};
|
||||
|
|
@ -36,11 +41,6 @@ template <typename T, typename Tind, bool dynamic>
|
|||
struct Slice final : public OpKernel, public SliceBase {
|
||||
Slice(const OpKernelInfo& info) : OpKernel(info), SliceBase(info, dynamic) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
private:
|
||||
void FillVectors(const OpKernelContext* context,
|
||||
std::vector<int64_t>& raw_starts,
|
||||
std::vector<int64_t>& raw_ends,
|
||||
std::vector<int64_t>& raw_axes) const;
|
||||
}; // namespace onnxruntime
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -493,7 +493,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, LSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, LSTM);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Compress);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Upsample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Upsample);
|
||||
|
|
@ -753,7 +756,10 @@ static void RegisterCudaKernels(std::function<void(KernelCreateInfo&&)> fn) {
|
|||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, LSTM)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, LSTM)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, LSTM)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, Slice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, Compress)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Upsample)>());
|
||||
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Upsample)>());
|
||||
|
|
|
|||
|
|
@ -8,15 +8,27 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Slice,
|
||||
kOnnxDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
Slice);
|
||||
#define REGISTER_TYPED_SLICE(NAME, TIND, DYNAMIC) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
NAME, \
|
||||
kOnnxDomain, \
|
||||
1, \
|
||||
TIND, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder().InputMemoryType<OrtMemTypeCPUInput>(1). \
|
||||
InputMemoryType<OrtMemTypeCPUInput>(2). \
|
||||
InputMemoryType<OrtMemTypeCPUInput>(3). \
|
||||
TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()). \
|
||||
TypeConstraint("Tind", DataTypeImpl::GetTensorType<TIND>()), \
|
||||
Slice<TIND,DYNAMIC>);
|
||||
|
||||
Status Slice::ComputeInternal(OpKernelContext* ctx) const {
|
||||
REGISTER_TYPED_SLICE(Slice, int32_t, false)
|
||||
REGISTER_TYPED_SLICE(Slice, int64_t, false)
|
||||
REGISTER_TYPED_SLICE(DynamicSlice, int32_t, true )
|
||||
REGISTER_TYPED_SLICE(DynamicSlice, int64_t, true )
|
||||
|
||||
template<typename Tind, bool dynamic>
|
||||
Status Slice<Tind, dynamic>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
auto input_tensor = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(nullptr != input_tensor);
|
||||
auto& input_dimensions = input_tensor->Shape().GetDims();
|
||||
|
|
@ -26,9 +38,16 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const {
|
|||
std::vector<int64_t> starts(dimension_count, 0);
|
||||
std::vector<int64_t> output_dims(input_dimensions);
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
dimension_count, input_dimensions,
|
||||
starts, output_dims));
|
||||
if (dynamic) {
|
||||
std::vector<int64_t> input_starts, input_ends, input_axes;
|
||||
FillVectorsFromInput<Tind>(ctx, input_starts, input_ends, input_axes);
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PrepareForCompute(attr_starts_, attr_ends_, attr_axes_,
|
||||
dimension_count, input_dimensions, starts, output_dims));
|
||||
}
|
||||
|
||||
TensorShape output_shape(output_dims);
|
||||
auto output_tensor = ctx->Output(0, output_shape);
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename Tind, bool dynamic>
|
||||
class Slice final : public CudaKernel, public SliceBase {
|
||||
public:
|
||||
Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info) {}
|
||||
Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -22,12 +22,14 @@ TEST(DynamicSliceTest, dynamic_slice_varied_types) {
|
|||
test2.AddOutput <int64_t> ("output", {2,2}, {5LL,6LL,8LL,9LL});
|
||||
test2.Run();
|
||||
|
||||
#ifndef USE_CUDA
|
||||
OpTester test3("DynamicSlice", 1);
|
||||
test3.AddInput <std::string> ("data", {3,3}, {"a","b","c","d","e","f","g","h","i"});
|
||||
test3.AddInput <int64_t> ("starts", {2}, {1,1});
|
||||
test3.AddInput <int64_t> ("ends", {2}, {3,3});
|
||||
test3.AddOutput <std::string> ("output", {2,2}, {"e","f","h","i"});
|
||||
test3.Run();
|
||||
#endif
|
||||
|
||||
OpTester test4("DynamicSlice", 1);
|
||||
test4.AddInput <float> ("data", {3,3}, {1.1f,2.2f,3.3f,4.4f,5.5f,6.6f,7.7f,8.8f,9.9f});
|
||||
|
|
@ -119,6 +121,5 @@ TEST(DynamicSliceTest, dynamic_slice_full_axes) {
|
|||
test2.AddOutput <int32_t> ("output", {1,2,1}, {5,8});
|
||||
test2.Run();
|
||||
}
|
||||
|
||||
} // namespace Test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue