Adding fp16 support for Einsum Cuda kernel (#6775)

* checkin einsum fp16 support

* remove unnecessary code

* add tests

* add another test
This commit is contained in:
Ye Wang 2021-02-24 01:15:29 -08:00 committed by GitHub
parent c02ec38f8a
commit 47c8e9ad28
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 232 additions and 4 deletions

View file

@ -456,5 +456,18 @@ template std::unique_ptr<Tensor> ReduceSum<int64_t>(
const std::vector<int64_t>& reduce_axes, AllocatorPtr allocator,
concurrency::ThreadPool* tp, void* einsum_cuda_assets, const DeviceHelpers::ReduceSum<int64_t>& reduce_sum_func);
// MLFloat16
template std::unique_ptr<Tensor> MatMul<MLFloat16>(
const Tensor& input_1, const std::vector<int64_t>& input_shape_1_override,
const Tensor& input_2, const std::vector<int64_t>& input_shape_2_override,
AllocatorPtr allocator, concurrency::ThreadPool* tp, void* einsum_cuda_assets,
const DeviceHelpers::MatMul<MLFloat16>& device_matmul_func);
template std::unique_ptr<Tensor> ReduceSum<MLFloat16>(
const Tensor& input, const std::vector<int64_t>& input_shape_override,
const std::vector<int64_t>& reduce_axes, AllocatorPtr allocator,
concurrency::ThreadPool* tp, void* einsum_cuda_assets,
const DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func);
} // namespace EinsumOp
} // namespace onnxruntime

View file

@ -367,5 +367,6 @@ template class EinsumTypedComputeProcessor<float>;
template class EinsumTypedComputeProcessor<int32_t>;
template class EinsumTypedComputeProcessor<double>;
template class EinsumTypedComputeProcessor<int64_t>;
template class EinsumTypedComputeProcessor<MLFloat16>;
} // namespace onnxruntime

View file

@ -15,7 +15,8 @@ ONNX_OPERATOR_KERNEL_EX(
KernelDefBuilder().TypeConstraint("T",
std::vector<MLDataType>{
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()}),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
Einsum);
Status Einsum::Compute(OpKernelContext* context) const {
@ -59,6 +60,16 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<double>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<MLFloat16>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<MLFloat16>(context, allocator, tp,
einsum_compute_preprocessor,
&einsum_cuda_assets);
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<MLFloat16>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<MLFloat16>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
return einsum_compute_processor.Run();
}
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,

View file

@ -165,6 +165,19 @@ template Tensor DeviceHelpers::CudaDeviceHelpers::ReduceSum<double>(
const TensorShape* input_shape_override,
concurrency::ThreadPool* tp, void* einsum_cuda_assets);
// MLFloat16
template Status DeviceHelpers::CudaDeviceHelpers::MatMul<MLFloat16>(
const MLFloat16* input_1_data, const MLFloat16* input_2_data, MLFloat16* output_data,
size_t left_stride, size_t right_stride, size_t output_stride,
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
void* einsum_cuda_assets);
template Tensor DeviceHelpers::CudaDeviceHelpers::ReduceSum<MLFloat16>(
const Tensor& input, const std::vector<int64_t>& reduce_axes,
bool keep_dims, AllocatorPtr allocator,
const TensorShape* input_shape_override,
concurrency::ThreadPool* tp, void* einsum_cuda_assets);
} // namespace EinsumOp
} // namespace onnxruntime

View file

@ -30,8 +30,8 @@ __global__ void _DiagonalKernel(
if (i == dim_1) {
// Process dim_2 as dim_2 needs to have the same dim value as dim_1
// For example: given a tensor of shape [2, 3, 3] and parsing the diagonal along axes `1` and `2`
// we need to parse elements in input[j, i, i] (j -> 0 to 1; and i -> 0 to 2)
// and place them in output[j, i] and by definition of diagonal parsing dim_1 has to be equal to
// we need to parse elements in input[j, i, i] (j -> 0 to 1; and i -> 0 to 2)
// and place them in output[j, i] and by definition of diagonal parsing dim_1 has to be equal to
// dim_2
input_idx += input_strides[dim_2] * dim;
}
@ -75,6 +75,13 @@ void DiagonalImpl(
output_size);
break;
case sizeof(int16_t):
_DiagonalKernel<half><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const half*>(input_data), input_rank, dim_1, dim_2,
input_strides, reinterpret_cast<half*>(output_data), output_strides,
output_size);
break;
// Should not hit this as we do not register kernel support for types that will run into this
default:
ORT_THROW("Einsum Op: Diagonal parsing unsupported");

View file

@ -610,7 +610,7 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
&zero, output_tensor, reinterpret_cast<CudaT*>(output.template MutableData<T>())));
}
}
} else {
} else {
// For ArgMax & ArgMin ops, use the indicies as the output with int64 type
// cudnnReduceTensor has issue if input and output has same size, which will happen if the axis to be reduced has dim value of 1.
// the output is zeros of the output size
@ -928,6 +928,13 @@ template Tensor ReduceCompute<double, CUDNN_REDUCE_TENSOR_NO_INDICES>(
bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp,
bool fast_reduction, const TensorShape* input_shape_override);
template Tensor ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO_INDICES>(
CUDAExecutionProvider& cuda_ep, cudnnReduceTensorOp_t cudnn_reduce_op,
AllocatorPtr allocator,
const Tensor& input, const std::vector<int64_t>& axes,
bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp,
bool fast_reduction, const TensorShape* input_shape_override);
} // namespace ReductionOps
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000

View file

@ -3,6 +3,7 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "core/framework/data_types.h"
#include "core/util/math.h"
@ -519,5 +520,180 @@ TEST(Einsum, ImplicitEinsumAsTensorContraction) {
test.Run();
}
// Test each theme for half support
TEST(Einsum, ExplicitEinsumAsIdentity_1D_input_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "i->i");
std::vector<float> input_x_f = {0.9f, 2.5f, 2.3f, 1.5f, -4.5f};
std::vector<float> output_f = {0.9f, 2.5f, 2.3f, 1.5f, -4.5f};
std::vector<MLFloat16> input_x(5);
std::vector<MLFloat16> output(5);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 5);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 5);
test.AddInput<MLFloat16>("x", {5}, input_x);
test.AddOutput<MLFloat16>("y", {5}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsTransposeOp_2D_input_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ji->ij");
std::vector<float> input_x_f = {1.f, 2.f, 3.f, 4.f};
std::vector<float> output_f = {1.f, 3.f, 2.f, 4.f};
std::vector<MLFloat16> input_x(4);
std::vector<MLFloat16> output(4);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 4);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 4);
test.AddInput<MLFloat16>("x", {2, 2}, input_x);
test.AddOutput<MLFloat16>("y", {2, 2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsReduceOp_2D_input_0_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ij->i");
std::vector<float> input_x_f = {1.f, 2.f, 3.f, 4.f};
std::vector<float> output_f = {3.f, 7.f};
std::vector<MLFloat16> input_x(4);
std::vector<MLFloat16> output(2);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 4);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 2);
test.AddInput<MLFloat16>("x", {2, 2}, input_x);
test.AddOutput<MLFloat16>("y", {2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsOuterProductOp_2D_input_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "i,j->ij");
std::vector<float> input_x_f = {1.f, 2.f};
std::vector<float> input_y_f = {3.f, 4.f};
std::vector<float> output_f = {3.f, 4.f, 6.f, 8.f};
std::vector<MLFloat16> input_x(2);
std::vector<MLFloat16> input_y(2);
std::vector<MLFloat16> output(4);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 2);
ConvertFloatToMLFloat16(input_y_f.data(), input_y.data(), 2);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 4);
test.AddInput<MLFloat16>("x", {2}, input_x);
test.AddInput<MLFloat16>("y", {2}, input_y);
test.AddOutput<MLFloat16>("o", {2, 2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsMatmul_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ij,jk->ik");
std::vector<float> input_x_f = {1.f, 2.f, 3.f, 4.f};
std::vector<float> input_y_f = {1.f, 2.f, 3.f, 4.f};
std::vector<float> output_f = {7.f, 10.f, 15.f, 22.f};
std::vector<MLFloat16> input_x(4);
std::vector<MLFloat16> input_y(4);
std::vector<MLFloat16> output(4);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 4);
ConvertFloatToMLFloat16(input_y_f.data(), input_y.data(), 4);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 4);
test.AddInput<MLFloat16>("x", {2, 2}, input_x);
test.AddInput<MLFloat16>("y", {2, 2}, input_y);
test.AddOutput<MLFloat16>("o", {2, 2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsBatchedMatmul_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "bij,bjk->bik");
std::vector<float> input_x_f = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f};
std::vector<float> input_y_f = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f};
std::vector<float> output_f = {7.f, 10.f, 15.f, 22.f, 7.f, 10.f, 15.f, 22.f};
std::vector<MLFloat16> input_x(8);
std::vector<MLFloat16> input_y(8);
std::vector<MLFloat16> output(8);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 8);
ConvertFloatToMLFloat16(input_y_f.data(), input_y.data(), 8);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 8);
test.AddInput<MLFloat16>("x", {2, 2, 2}, input_x);
test.AddInput<MLFloat16>("y", {2, 2, 2}, input_y);
test.AddOutput<MLFloat16>("o", {2, 2, 2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsDiagonalOp_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ii->i");
std::vector<float> input_x_f = {1.f, 2.f, 3.f, 4.f};
std::vector<float> output_f = {1.f, 4.f};
std::vector<MLFloat16> input_x(4);
std::vector<MLFloat16> output(2);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 4);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 2);
test.AddInput<MLFloat16>("x", {2, 2}, input_x);
test.AddOutput<MLFloat16>("o", {2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", ",...i->...i");
std::vector<float> input_x_f = {10.f};
std::vector<float> input_y_f = {1.f, 2.f, 3.f, 4.f};
std::vector<float> output_f = {10.f, 20.f, 30.f, 40.f};
std::vector<MLFloat16> input_x(1);
std::vector<MLFloat16> input_y(4);
std::vector<MLFloat16> output(4);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 1);
ConvertFloatToMLFloat16(input_y_f.data(), input_y.data(), 4);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 4);
test.AddInput<MLFloat16>("x", {}, input_x);
test.AddInput<MLFloat16>("y", {2, 2}, input_y);
test.AddOutput<MLFloat16>("o", {2, 2}, output);
test.Run();
}
TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) {
if (!HasCudaEnvironment(600)) {
return;
}
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "abcd,ea->bcde");
std::vector<float> input_x_f = {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f};
std::vector<float> input_y_f = {1.f, 2.f, 1.f, 2.f};
std::vector<float> output_f = {3.f, 3.f, 6.f, 6.f, 3.f, 3.f, 6.f, 6.f, 3.f, 3.f, 6.f, 6.f, 3.f, 3.f, 6.f, 6.f};
std::vector<MLFloat16> input_x(16);
std::vector<MLFloat16> input_y(4);
std::vector<MLFloat16> output(16);
ConvertFloatToMLFloat16(input_x_f.data(), input_x.data(), 16);
ConvertFloatToMLFloat16(input_y_f.data(), input_y.data(), 4);
ConvertFloatToMLFloat16(output_f.data(), output.data(), 16);
test.AddInput<MLFloat16>("x", {2, 2, 2, 2}, input_x);
test.AddInput<MLFloat16>("y", {2, 2}, input_y);
test.AddOutput<MLFloat16>("o", {2, 2, 2, 2}, output);
test.Run();
}
} // namespace test
} // namespace onnxruntime