Cuda pad() for opset 11 (#2490)

* Cuda pad opset 11.

* Handle a type conversion issue that occurred when building.
This commit is contained in:
Zhang Lei 2019-12-02 16:28:17 -08:00 committed by GitHub
parent b9faa0b6fd
commit 784eca0dcd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 10 deletions

View file

@ -390,7 +390,7 @@ Status Pad<T>::Compute(OpKernelContext* ctx) const {
}
}
T value = 0;
T value = static_cast<T>(0);
const Tensor* value_tensor = ctx->Input<Tensor>(2);
if (nullptr != value_tensor) {
ORT_ENFORCE(utils::IsPrimitiveDataType<T>(value_tensor->DataType()) &&

View file

@ -660,6 +660,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, uint8_t, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Clip);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal);
@ -1118,6 +1121,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, uint8_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Clip)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, int64_t, Equal)>,

View file

@ -18,13 +18,82 @@ namespace cuda {
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Pad<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Pad, \
kOnnxDomain, \
11, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.InputMemoryType<OrtMemTypeCPUInput>(1) \
.InputMemoryType<OrtMemTypeCPUInput>(2) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Pad<T>);
// Converts a host-side value of element type T into the type the CUDA side
// works with, ToCudaType<T>::MappedType.
// The primary template performs only the implicit conversion from T to the
// mapped type; MLFloat16 is handled by an explicit specialization.
template <typename T>
typename ToCudaType<T>::MappedType ToCudaValue(const T& value) {
  typedef typename ToCudaType<T>::MappedType CudaT;
  CudaT converted = value;
  return converted;
}
// MLFloat16 specialization of ToCudaValue.
// Reads the storage of value.val through a pointer to
// ToCudaType<MLFloat16>::MappedType and returns that object by value; no
// numeric conversion is performed.
// NOTE(review): presumably `val` holds the raw 16-bit half pattern and the
// mapped type has the same size and layout - confirm against their definitions.
template <>
typename ToCudaType<MLFloat16>::MappedType ToCudaValue<MLFloat16>(const MLFloat16& value) {
  typedef typename ToCudaType<MLFloat16>::MappedType CudaHalf;
  const CudaHalf* half_bits = reinterpret_cast<const CudaHalf*>(&value.val);
  return *half_bits;
}
template <typename T>
Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const auto& input_tensor = *ctx->Input<Tensor>(0);
auto const& input_shape = input_tensor.Shape();
auto dimension_count = input_shape.NumDimensions();
const std::vector<int64_t>* p_pads = &pads_;
const std::vector<int64_t>* p_slices = &slices_;
CudaT value = ToCudaType<T>::FromFloat(value_);
// kOnnxDomain Pad opset >= 11, or kMsDomain Pad opset == 1
std::vector<int64_t> pads;
std::vector<int64_t> slices;
if (is_dynamic_) {
const Tensor& pads_tensor = *ctx->Input<Tensor>(1);
const std::vector<int64_t>& pads_tensor_dims = pads_tensor.Shape().GetDims();
ORT_ENFORCE(utils::IsPrimitiveDataType<int64_t>(pads_tensor.DataType()),
"Pads tensor should be an INT64 tensor");
ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1),
"Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]");
const int64_t* pads_tensor_raw_data = pads_tensor.template Data<int64_t>();
size_t pads_size = static_cast<size_t>(pads_tensor.Shape().Size());
ORT_ENFORCE(pads_size == 2 * dimension_count,
"Pads tensor size should be equal to twice the input dimension count ");
pads.reserve(2 * dimension_count);
for (size_t i = 0; i < pads_size; ++i) {
pads.push_back(pads_tensor_raw_data[i]);
}
// Separate out any negative pads into the slices array
slices.resize(pads.size(), 0);
for (size_t index = 0; index < pads.size(); index++) {
if (pads[index] < 0) {
slices[index] = pads[index];
pads[index] = 0;
}
}
T raw_value(0);
const Tensor* value_tensor = ctx->Input<Tensor>(2);
if (nullptr != value_tensor) {
ORT_ENFORCE(utils::IsPrimitiveDataType<T>(value_tensor->DataType()) &&
value_tensor->Shape().Size() == 1,
"Value tensor should be a 1D tensor of size 1 with the same type as that of the input tensor");
raw_value = value_tensor->template Data<T>()[0];
value = ToCudaValue<T>(raw_value);
}
p_pads = &pads;
p_slices = &slices;
}
CudaAsyncBuffer<int64_t> input_dims(this, input_shape.GetDims());
CudaAsyncBuffer<int64_t> input_strides(this, dimension_count);
CudaAsyncBuffer<int64_t> lower_pads(this, dimension_count);
@ -33,15 +102,14 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
TensorPitches::Calculate(input_strides.CpuSpan(), input_shape.GetDims());
std::vector<int64_t> output_dims(input_shape.GetDims());
ORT_ENFORCE(dimension_count * 2 == pads_.size(), "'pads' attribute has wrong number of values");
ORT_ENFORCE(dimension_count * 2 == p_pads->size(), "'pads' attribute has wrong number of values");
// Calculate output dimensions, and handle any negative padding
auto lower_pads_span = lower_pads.CpuSpan();
auto upper_pads_span = upper_pads.CpuSpan();
for (size_t i = 0; i < dimension_count; i++) {
lower_pads_span[i] = pads_[i] + slices_[i];
upper_pads_span[i] = pads_[i + dimension_count] + slices_[i + dimension_count];
lower_pads_span[i] = (*p_pads)[i] + (*p_slices)[i];
upper_pads_span[i] = (*p_pads)[i + dimension_count] + (*p_slices)[i + dimension_count];
output_dims[i] += lower_pads_span[i] + upper_pads_span[i];
}
TensorShape output_shape(output_dims);
@ -65,7 +133,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
input_strides.GpuPtr(),
lower_pads.GpuPtr(),
upper_pads.GpuPtr(),
value_,
value,
static_cast<int>(mode_),
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(input_tensor.template Data<T>()),
fdm_output_strides.GpuPtr(),

View file

@ -6,6 +6,8 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/tensor/pad.h"
using onnxruntime::PadBase;
namespace onnxruntime {
namespace cuda {

View file

@ -21,7 +21,7 @@ __global__ void _PadKernel(
const int64_t* input_strides,
const int64_t* lower_pads,
const int64_t* upper_pads,
const float pad_value,
const T pad_value,
const T* input_data,
const fast_divmod* fdm_output_strides,
T* output_data,
@ -74,7 +74,7 @@ void PadImpl(
const int64_t* input_strides,
const int64_t* lower_pads,
const int64_t* upper_pads,
const float pad_value,
const T pad_value,
const int pad_mode,
const T* input_data,
const fast_divmod* fdm_output_strides,
@ -104,7 +104,7 @@ void PadImpl(
}
#define SPECIALIZED_IMPL(T) \
template void PadImpl<T>(const size_t shape_rank, const int64_t* input_dims, const int64_t* input_strides, const int64_t* lower_pads, const int64_t* upper_pads, const float pad_value, const int pad_mode, const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, const size_t N);
template void PadImpl<T>(const size_t shape_rank, const int64_t* input_dims, const int64_t* input_strides, const int64_t* lower_pads, const int64_t* upper_pads, const T pad_value, const int pad_mode, const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, const size_t N);
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)

View file

@ -15,7 +15,7 @@ void PadImpl(
const int64_t* input_strides,
const int64_t* lower_pads,
const int64_t* upper_pads,
const float pad_value,
const T pad_value,
const int pad_mode,
const T* input_data,
const fast_divmod* fdm_output_strides,