mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Change CUDA implementation of Transpose to support all fixed size tensor types (#2387)
* Change CUDA implementation of Transpose to not use a typed kernel so we can support more types with minimum binary size. Add support for 8, 16, 32 and 64 bit types. Add unit tests. Add method so the implementation can be called directly (will be used by CUDA Scan very soon). * Disable TensorRT for MLFloat16 and int8 unit tests. * Address PR comment and add support for calling cublas implementation if type is mlfloat16.
This commit is contained in:
parent
109b3cb450
commit
6cc57721f4
6 changed files with 201 additions and 103 deletions
|
|
@ -486,9 +486,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Sh
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Transpose);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Transpose);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Transpose);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization);
|
||||
|
|
@ -863,9 +861,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization)>,
|
||||
|
|
|
|||
|
|
@ -9,19 +9,16 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Transpose, \
|
||||
kOnnxDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Transpose<T>);
|
||||
ONNX_OPERATOR_KERNEL_EX(Transpose,
|
||||
kOnnxDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
|
||||
Transpose);
|
||||
|
||||
// special case acceleration using cublas matrix transpose
|
||||
std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, const TensorShape& input_shape) {
|
||||
static std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, const TensorShape& input_shape) {
|
||||
int M = 0;
|
||||
int N = 0;
|
||||
|
||||
|
|
@ -47,7 +44,72 @@ std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, con
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
Status TransposeWithCublas(cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, int M, int N) {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
CudaT one = ToCudaType<T>::FromFloat(1.0f);
|
||||
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(input.Data<T>());
|
||||
CudaT* output_data = reinterpret_cast<CudaT*>(output.MutableData<T>());
|
||||
CUBLAS_RETURN_IF_ERROR(
|
||||
cublasTransposeHelper(cublas_handle,
|
||||
CUBLAS_OP_T, CUBLAS_OP_T, M, N,
|
||||
&one,
|
||||
input_data,
|
||||
N,
|
||||
&zero,
|
||||
input_data,
|
||||
N,
|
||||
output_data,
|
||||
M));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Transpose::DoTranspose(const Transpose& kernel,
|
||||
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output) {
|
||||
// special case when there is a dim value of 0 in the shape.
|
||||
if (output.Shape().Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
auto element_type = input.GetElementType();
|
||||
if (element_type == utils::GetONNXTensorElementDataType<float>() ||
|
||||
element_type == utils::GetONNXTensorElementDataType<double>() ||
|
||||
element_type == utils::GetONNXTensorElementDataType<MLFloat16>()) {
|
||||
auto mn = TryTransposeWithCublas(permutations, input.Shape());
|
||||
int M = std::get<0>(mn);
|
||||
int N = std::get<1>(mn);
|
||||
if (M != 0 && N != 0) {
|
||||
if (element_type == utils::GetONNXTensorElementDataType<float>()) {
|
||||
return TransposeWithCublas<float>(kernel.CublasHandle(), input, output, M, N);
|
||||
} else if (element_type == utils::GetONNXTensorElementDataType<double>()) {
|
||||
return TransposeWithCublas<double>(kernel.CublasHandle(), input, output, M, N);
|
||||
} else {
|
||||
return TransposeWithCublas<MLFloat16>(kernel.CublasHandle(), input, output, M, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& input_dims = input.Shape().GetDims();
|
||||
const std::vector<int64_t>& output_dims = output.Shape().GetDims();
|
||||
|
||||
auto rank = input_dims.size();
|
||||
CudaAsyncBuffer<int64_t> input_strides(&kernel, rank);
|
||||
CudaAsyncBuffer<size_t> perm(&kernel, permutations);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(&kernel, rank);
|
||||
ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims));
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));
|
||||
|
||||
ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(perm.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());
|
||||
|
||||
size_t element_size = input.DataType()->Size();
|
||||
auto status = TransposeImpl(element_size, rank, input_strides.GpuPtr(), perm.GpuPtr(), input.DataRaw(),
|
||||
fdm_output_strides.GpuPtr(), output.MutableDataRaw(), output.Shape().Size());
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
Status Transpose::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* X_ptr = ctx->Input<Tensor>(0);
|
||||
if (X_ptr == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
const Tensor& X = *X_ptr;
|
||||
|
|
@ -65,66 +127,8 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
TensorShape output_shape{output_dims};
|
||||
Tensor* Y = ctx->Output(0, output_shape);
|
||||
|
||||
// special case when there is a dim value of 0 in the shape.
|
||||
if (output_shape.Size() == 0)
|
||||
return Status::OK();
|
||||
|
||||
auto mn = TryTransposeWithCublas(*p_perm, input_shape);
|
||||
int M = std::get<0>(mn);
|
||||
int N = std::get<1>(mn);
|
||||
if (M != 0 && N != 0) {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
CudaT one = ToCudaType<T>::FromFloat(1.0f);
|
||||
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(X.template Data<T>());
|
||||
CudaT* output_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
|
||||
CUBLAS_RETURN_IF_ERROR(
|
||||
cublasTransposeHelper(
|
||||
CublasHandle(),
|
||||
CUBLAS_OP_T,
|
||||
CUBLAS_OP_T,
|
||||
M,
|
||||
N,
|
||||
&one,
|
||||
input_data,
|
||||
N,
|
||||
&zero,
|
||||
input_data,
|
||||
N,
|
||||
output_data,
|
||||
M));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
CudaAsyncBuffer<int64_t> input_strides(this, rank);
|
||||
CudaAsyncBuffer<size_t> perm(this, *p_perm);
|
||||
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, rank);
|
||||
ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims));
|
||||
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));
|
||||
|
||||
ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(perm.CopyToGpu());
|
||||
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());
|
||||
|
||||
TransposeImpl(
|
||||
rank,
|
||||
input_strides.GpuPtr(),
|
||||
perm.GpuPtr(),
|
||||
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(X.template Data<T>()),
|
||||
fdm_output_strides.GpuPtr(),
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(Y->template MutableData<T>()),
|
||||
output_shape.Size());
|
||||
|
||||
return Status::OK();
|
||||
return DoTranspose(*this, *p_perm, X, *Y);
|
||||
}
|
||||
|
||||
#define SPECIALIZED_COMPUTE(T) \
|
||||
REGISTER_KERNEL_TYPED(T) \
|
||||
template Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const;
|
||||
|
||||
SPECIALIZED_COMPUTE(float)
|
||||
SPECIALIZED_COMPUTE(double)
|
||||
SPECIALIZED_COMPUTE(MLFloat16)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,12 +12,14 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
class Transpose final : public CudaKernel, public TransposeBase {
|
||||
public:
|
||||
Transpose(const OpKernelInfo& info) : CudaKernel(info), TransposeBase(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
static Status DoTranspose(const Transpose& transpose_kernel,
|
||||
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output);
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -23,23 +23,49 @@ __global__ void _TransposeKernel(size_t shape_rank, const int64_t* input_strides
|
|||
output_data[id] = input_data[input_index];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data,
|
||||
const fast_divmod* fdm_output_strides, T* output_data, size_t N) {
|
||||
Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm,
|
||||
const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N) {
|
||||
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
|
||||
_TransposeKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
shape_rank, input_strides, perm, input_data,
|
||||
fdm_output_strides, output_data, N);
|
||||
switch (element_size) {
|
||||
case sizeof(int8_t):
|
||||
_TransposeKernel<int8_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
shape_rank, input_strides, perm,
|
||||
reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
|
||||
fdm_output_strides,
|
||||
reinterpret_cast<ToCudaType<int8_t>::MappedType*>(output_data),
|
||||
N);
|
||||
break;
|
||||
case sizeof(int16_t):
|
||||
_TransposeKernel<int16_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
shape_rank, input_strides, perm,
|
||||
reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
|
||||
fdm_output_strides,
|
||||
reinterpret_cast<ToCudaType<int16_t>::MappedType*>(output_data),
|
||||
N);
|
||||
break;
|
||||
case sizeof(int32_t):
|
||||
_TransposeKernel<int32_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
shape_rank, input_strides, perm,
|
||||
reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
|
||||
fdm_output_strides,
|
||||
reinterpret_cast<ToCudaType<int32_t>::MappedType*>(output_data),
|
||||
N);
|
||||
break;
|
||||
case sizeof(int64_t):
|
||||
_TransposeKernel<int64_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
shape_rank, input_strides, perm,
|
||||
reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
|
||||
fdm_output_strides,
|
||||
reinterpret_cast<ToCudaType<int64_t>::MappedType*>(output_data),
|
||||
N);
|
||||
break;
|
||||
default:
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
|
||||
element_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define SPECIALIZED_IMPL(T) \
|
||||
template void TransposeImpl<T>(size_t shape_rank, const int64_t* input_strides, const size_t* perm, \
|
||||
const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, \
|
||||
size_t N);
|
||||
|
||||
SPECIALIZED_IMPL(float)
|
||||
SPECIALIZED_IMPL(double)
|
||||
SPECIALIZED_IMPL(half)
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,9 +8,8 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data,
|
||||
const fast_divmod* fdm_output_strides, T* output_data, size_t N);
|
||||
Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm,
|
||||
const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -64,20 +64,91 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) {
|
|||
// Test 2 dimensional transpose, with permutation attribute specified
|
||||
TEST(TransposeOpTest, TwoDim) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<float> input_vals = {
|
||||
1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f};
|
||||
std::vector<float> input_vals = {1.0f, 2.0f, 3.0f,
|
||||
4.0f, 5.0f, 6.0f};
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
auto expected_vals = {
|
||||
1.0f, 4.0f,
|
||||
2.0f, 5.0f,
|
||||
3.0f, 6.0f};
|
||||
auto expected_vals = {1.0f, 4.0f,
|
||||
2.0f, 5.0f,
|
||||
3.0f, 6.0f};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDim_double) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<double> input_vals = {1.0, 2.0, 3.0,
|
||||
4.0, 5.0, 6.0};
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<double> expected_vals = {1.0, 4.0,
|
||||
2.0, 5.0,
|
||||
3.0, 6.0};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDim_int32) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<int32_t> input_vals = {1, 2, 3,
|
||||
4, 5, 6};
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<int32_t> expected_vals = {1, 4,
|
||||
2, 5,
|
||||
3, 6};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDim_int16) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<int16_t> input_vals = {
|
||||
1, 2, 3,
|
||||
4, 5, 6};
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<int16_t> expected_vals = {
|
||||
1, 4,
|
||||
2, 5,
|
||||
3, 6};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDim_mlfloat16) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<MLFloat16> input_vals;
|
||||
for (uint16_t i = 0; i < 6; ++i)
|
||||
input_vals.push_back(MLFloat16(i));
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<MLFloat16> expected_vals = {MLFloat16(1), MLFloat16(4),
|
||||
MLFloat16(2), MLFloat16(5),
|
||||
MLFloat16(3), MLFloat16(6)};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDim_int8) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<int8_t> input_vals = {1, 2, 3,
|
||||
4, 5, 6};
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<int8_t> expected_vals = {1, 4,
|
||||
2, 5,
|
||||
3, 6};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDimStr) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<std::string> input_vals = {
|
||||
|
|
|
|||
Loading…
Reference in a new issue