[ROCm] Enable Einsum for inferencing perf (#12360)

* enable einsum

* address comments

* comments

Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
ytaous 2022-07-28 20:26:25 -07:00 committed by GitHub
parent 9c0fa65110
commit e4bd41fb3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 477 additions and 3 deletions

View file

@ -24,8 +24,6 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) {
ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(),
"Einsum op: The candidate output does not match the actual output's shape");
// There are no string tensors in Einsum's case - so safely use memcpy
// TODO: Currently, triggers copy on stream 0, investigate if we can still do that
// *if* the kernel is launched in a different stream
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(),
cudaMemcpyDeviceToDevice,
static_cast<cudaStream_t>(static_cast<EinsumCudaAssets*>(einsum_cuda_assets)->cuda_ep_->GetComputeStream())));

View file

@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "einsum.h"
namespace onnxruntime {
// Placeholder definition of the base class (onnxruntime::Einsum) DeviceCompute.
// The ROCM kernel overrides this virtual method, so a definition has to exist for the base class
// vtable even though this version is never expected to run.
// Debug builds trip the assert. Builds where assert is compiled out report a NOT_IMPLEMENTED status
// instead of claiming success without having computed anything.
Status Einsum::DeviceCompute(OpKernelContext* /*context*/, const std::vector<const Tensor*>& /*inputs*/,
                             AllocatorPtr /*allocator*/, concurrency::ThreadPool* /*tp*/) const {
  assert(false);
  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "Einsum op: the base class DeviceCompute placeholder must not be invoked");
}
namespace rocm {
// Registers the Einsum kernel class for the ONNX domain, opset 12, with the ROCM execution provider.
// The "T" type constraint lists float and MLFloat16, which are the two element types that
// Einsum::DeviceCompute dispatches on.
ONNX_OPERATOR_KERNEL_EX(
Einsum,
kOnnxDomain,
12,
kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<MLFloat16>()}),
Einsum);
// Kernel entry point: forwards the call unchanged to the shared onnxruntime::Einsum::Compute
// implementation and returns its status.
// NOTE(review): the base implementation presumably calls back into the DeviceCompute override of
// this class - confirm in core/providers/cpu/math/einsum.h.
Status Einsum::Compute(OpKernelContext* context) const {
return onnxruntime::Einsum::Compute(context);
}
// ROCM implementation of the Einsum device computation.
// Runs the compute preprocessor and then the typed compute processor, wiring both to the ROCM
// device helpers (Diagonal, Transpose, MatMul, ReduceSum, DataCopy).
// Returns the preprocessor's error status if it fails, otherwise the processor's status, or
// NOT_IMPLEMENTED when the first input is neither float nor MLFloat16.
Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const Tensor*>& inputs,
                             AllocatorPtr allocator, concurrency::ThreadPool* tp) const {
  namespace RocmHelpers = EinsumOp::DeviceHelpers::RocmDeviceHelpers;

  // Bundle the rocBLAS handle and the EP; the helpers receive this bundle as an opaque pointer
  rocblas_handle blas_handle = rocm_ep_->PerThreadRocblasHandle();
  EinsumOp::EinsumRocmAssets rocm_assets(blas_handle, rocm_ep_);

  // Preprocessing stage -
  // computes all the metadata required at compute time; bail out on the first error
  auto preprocessor = EinsumComputePreprocessor::Create(*einsum_equation_preprocessor_, inputs, allocator,
                                                        &rocm_assets);
  preprocessor->SetDeviceHelpers(RocmHelpers::Diagonal, RocmHelpers::Transpose);
  ORT_RETURN_IF_ERROR(preprocessor->Run());

  // Compute stage - the processor is selected by the element type of the first input
  const Tensor* first_input = inputs[0];

  if (first_input->IsDataType<float>()) {
    auto processor = EinsumTypedComputeProcessor<float>::Create(context, allocator, tp,
                                                                *preprocessor, &rocm_assets);
    processor->SetDeviceHelpers(RocmHelpers::Transpose,
                                RocmHelpers::MatMul<float>,
                                RocmHelpers::ReduceSum<float>,
                                RocmHelpers::DataCopy);
    return processor->Run();
  }

  if (first_input->IsDataType<MLFloat16>()) {
    auto processor = EinsumTypedComputeProcessor<MLFloat16>::Create(context, allocator, tp,
                                                                    *preprocessor, &rocm_assets);
    processor->SetDeviceHelpers(RocmHelpers::Transpose,
                                RocmHelpers::MatMul<MLFloat16>,
                                RocmHelpers::ReduceSum<MLFloat16>,
                                RocmHelpers::DataCopy);
    return processor->Run();
  }

  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "Einsum op: An implementation for the input type ",
                         first_input->DataType(), " is not supported yet");
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/platform/threadpool.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/cpu/math/einsum.h"
#include "einsum_utils/einsum_auxiliary_ops.h"
#include "core/providers/rocm/rocm_execution_provider.h"
namespace onnxruntime {
namespace rocm {
// ROCM flavour of the Einsum kernel: derives from the shared onnxruntime::Einsum kernel and
// overrides its Compute and DeviceCompute methods.
class Einsum final : public onnxruntime::Einsum {
 public:
  // Remembers the ROCM execution provider this kernel was created for.
  // The const is cast away because PerThreadRocblasHandle() is currently a non-const method.
  // TODO: Clean up the ROCMExecutionProvider interface to avoid this
  Einsum(const OpKernelInfo& info)
      : onnxruntime::Einsum(info),
        rocm_ep_(const_cast<ROCMExecutionProvider*>(
            static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()))) {
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  Status DeviceCompute(OpKernelContext* context, const std::vector<const Tensor*>& inputs,
                       AllocatorPtr allocator, concurrency::ThreadPool* tp) const override;

  // Base class members made accessible to the ROCM kernel
  using onnxruntime::Einsum::einsum_equation_preprocessor_;
  using onnxruntime::Einsum::equation_;

  // The ROCM EP instance is needed to obtain the rocblas/miopen handles
  ROCMExecutionProvider* rocm_ep_;
};
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,175 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
namespace onnxruntime {
namespace concurrency {
class ThreadPool;
}
} // namespace onnxruntime
#include "einsum_auxiliary_ops.h"
namespace onnxruntime {
namespace EinsumOp {
namespace DeviceHelpers {
namespace RocmDeviceHelpers {
// ROCM EP specific Data copy helper
// Enqueues an asynchronous device-to-device copy of `input`'s bytes into `output` on the compute
// stream of the execution provider held by the assets.
// `einsum_rocm_assets` is cast to EinsumRocmAssets* without any check, so callers must pass one.
// Throws (ORT_ENFORCE) when the two tensors differ in byte size; otherwise returns Status::OK()
// once the copy has been enqueued without a HIP error.
Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets) {
ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(),
"Einsum op: The candidate output does not match the actual output's shape");
// There are no string tensors in Einsum's case - so safely use memcpy
HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(),
hipMemcpyDeviceToDevice,
static_cast<hipStream_t>(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_->GetComputeStream())));
return Status::OK();
}
// ROCM EP specific Transpose helper.
// Unpacks the opaque assets pointer (it must point to an EinsumRocmAssets) and forwards to
// rocm::Transpose::DoTranspose with the EP's device properties, the EP's compute stream and the
// rocBLAS handle stored in the assets. Returns whatever DoTranspose returns.
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
                 Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets) {
  auto* assets = static_cast<EinsumRocmAssets*>(einsum_rocm_assets);
  auto* rocm_ep = assets->rocm_ep_;

  return rocm::Transpose::DoTranspose(rocm_ep->GetDeviceProp(),
                                      static_cast<hipStream_t>(rocm_ep->GetComputeStream()),
                                      assets->rocblas_handle_,
                                      permutation, input, output, input_shape_override);
}
// ROCM EP specific MatMul helper
// Issues one strided-batched GEMM through rocblasGemmStridedBatchedHelper using the rocBLAS handle
// stored in the assets (`einsum_rocm_assets` is cast to EinsumRocmAssets* without any check).
// `left_stride`, `right_stride` and `output_stride` are forwarded as the batch strides of
// `input_1_data`, `input_2_data` and `output_data`; `num_batches` is the batch count.
// The thread pool argument is unused.
// Returns Status::OK() unless the rocBLAS call reports an error.
// NOTE(review): every size_t argument is narrowed with static_cast<int>, so values above INT_MAX
// would be truncated silently - confirm callers never exceed that range.
template <typename T>
Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data,
size_t left_stride, size_t right_stride, size_t output_stride,
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* /*tp*/,
void* einsum_rocm_assets) {
typedef typename rocm::ToHipType<T>::MappedType HipT;
// Scaling factors handed to the GEMM by address: alpha = 1 and beta = 0
HipT one = rocm::ToHipType<T>::FromFloat(1.0f);
HipT zero = rocm::ToHipType<T>::FromFloat(0.0f);
// The operands are passed in swapped order: input_2 first (leading dimension N), then input_1
// (leading dimension K), with the sizes given as (N, M, K) and an output leading dimension of N.
// NOTE(review): presumably this is the usual way to drive a column-major GEMM with row-major
// buffers - confirm against the rocBLAS documentation.
ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocblas_handle_,
rocblas_operation_none,
rocblas_operation_none,
static_cast<int>(N),
static_cast<int>(M),
static_cast<int>(K),
&one,
reinterpret_cast<const HipT*>(input_2_data),
static_cast<int>(N),
static_cast<int>(right_stride),
reinterpret_cast<const HipT*>(input_1_data),
static_cast<int>(K),
static_cast<int>(left_stride),
&zero,
reinterpret_cast<HipT*>(output_data),
static_cast<int>(N),
static_cast<int>(output_stride),
static_cast<int>(num_batches)));
return Status::OK();
}
// ROCM EP specific ReduceSum helper.
// Forwards to rocm::ReductionOps::ReduceCompute<T> with the MIOPEN_REDUCE_TENSOR_ADD operation,
// using the execution provider stored in the assets (the pointer must refer to an
// EinsumRocmAssets). The thread pool argument is unused. Returns the tensor ReduceCompute creates.
template <typename T>
std::unique_ptr<Tensor> ReduceSum(const Tensor& input, gsl::span<const int64_t> reduce_axes,
                                  bool keep_dims, AllocatorPtr allocator,
                                  const TensorShape* input_shape_override,
                                  concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets) {
  auto* assets = static_cast<EinsumRocmAssets*>(einsum_rocm_assets);

  // NOTE(review): the three `false` flags and the `true` flag that follow keep_dims are positional
  // options of ReduceCompute - their meaning is declared in reduction_ops.h, confirm there.
  return rocm::ReductionOps::ReduceCompute<T>(*assets->rocm_ep_, MIOPEN_REDUCE_TENSOR_ADD,
                                              allocator, input, reduce_axes,
                                              keep_dims, false, false, false,
                                              true, input_shape_override);
}
// ROCM EP specific Diagonal helper.
// Extracts the diagonal of `input` along the axes `dim_1` and `dim_2` into a newly created tensor.
// The output shape is the input shape with the greater of the two axes removed; the diagonal
// values are stored along the lesser axis.
// `einsum_rocm_assets` is cast to EinsumRocmAssets* without any check; its execution provider
// supplies the compute stream the kernel is launched on.
// Throws (ORT_ENFORCE) when either axis lies outside [0, rank), when rank < 2, when the two axes
// are identical or when the input's extents along the two axes differ.
std::unique_ptr<Tensor> Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets) {
  const auto& input_shape = input.Shape();
  const auto& input_dims = input_shape.GetDims();
  const auto rank = static_cast<int64_t>(input_dims.size());

  // Validate the axes before they are used as indices into the dims
  ORT_ENFORCE(dim_1 >= 0 && dim_1 < rank && dim_2 >= 0 && dim_2 < rank,
              "Diagonal dims ", dim_1, " and ", dim_2, " are out of range for input shape ", input_shape);
  ORT_ENFORCE(rank >= 2 && dim_1 != dim_2 && input_dims[dim_1] == input_dims[dim_2],
              "Cannot parse the diagonal elements along dims ", dim_1, " and ", dim_2, " for input shape ", input_shape);

  // The axes are known to differ at this point, so one is strictly the lesser
  const int64_t first_dim = dim_1 < dim_2 ? dim_1 : dim_2;   // lesser of dim_1 and dim_2
  const int64_t second_dim = dim_1 < dim_2 ? dim_2 : dim_1;  // greater of dim_1 and dim_2

  // Make a copy - we are going to mutate the dims
  TensorShapeVector output_dims = input_shape.AsShapeVector();

  // Remove the dim value in `second_dim` -
  // The diagonal values are stored along `first_dim`
  output_dims.erase(output_dims.begin() + second_dim);

  auto output = Tensor::Create(input.DataType(), output_dims, allocator);

  // Input strides are passed to the kernel as plain int64 values
  TensorPitches input_strides(input.Shape().GetDims());
  rocm::TArray<int64_t> gpu_input_strides(input_strides);

  // Output strides are wrapped in fast_divmod, which the kernel uses to split a flat output index
  const auto output_rank = static_cast<int32_t>(output_dims.size());
  rocm::TArray<rocm::fast_divmod> gpu_output_strides(output_rank);
  TensorPitches output_strides(output_dims);
  for (int32_t i = 0; i < output_rank; ++i) {
    gpu_output_strides[i] = rocm::fast_divmod(static_cast<int>(output_strides[i]));
  }

  DiagonalImpl(
      static_cast<hipStream_t>(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_->GetComputeStream()),
      input.DataRaw(),
      rank,
      first_dim,
      second_dim,
      gpu_input_strides,
      output->MutableDataRaw(),
      gpu_output_strides,
      TensorShape(output_dims).Size(),
      input.DataType()->Size());

  return output;
}
} // namespace RocmDeviceHelpers
} // namespace DeviceHelpers
// Explicit template instantiations of functions
// MatMul and ReduceSum are function templates whose definitions live in this source file, so the
// float and MLFloat16 specializations used by the ROCM Einsum kernel are instantiated here.

// float
template Status DeviceHelpers::RocmDeviceHelpers::MatMul<float>(
const float* input_1_data, const float* input_2_data, float* output_data,
size_t left_stride, size_t right_stride, size_t output_stride,
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
void* einsum_rocm_assets);
template std::unique_ptr<Tensor> DeviceHelpers::RocmDeviceHelpers::ReduceSum<float>(
const Tensor& input, gsl::span<const int64_t> reduce_axes,
bool keep_dims, AllocatorPtr allocator,
const TensorShape* input_shape_override,
concurrency::ThreadPool* tp, void* einsum_rocm_assets);

// MLFloat16
template Status DeviceHelpers::RocmDeviceHelpers::MatMul<MLFloat16>(
const MLFloat16* input_1_data, const MLFloat16* input_2_data, MLFloat16* output_data,
size_t left_stride, size_t right_stride, size_t output_stride,
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
void* einsum_rocm_assets);
template std::unique_ptr<Tensor> DeviceHelpers::RocmDeviceHelpers::ReduceSum<MLFloat16>(
const Tensor& input, gsl::span<const int64_t> reduce_axes,
bool keep_dims, AllocatorPtr allocator,
const TensorShape* input_shape_override,
concurrency::ThreadPool* tp, void* einsum_rocm_assets);
} // namespace EinsumOp
} // namespace onnxruntime

View file

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This module hosts implementations and thin wrappers over other onnx operator implementations
// that will be called from within the Einsum operator implementation
#pragma once
#include "core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h"
#include "core/providers/rocm/tensor/transpose.h"
#include "core/providers/rocm/reduction/reduction_ops.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/providers/cpu/tensor/utils.h"
#include "einsum_auxiliary_ops_diagonal.h"
#include "core/providers/rocm/rocm_common.h"
namespace onnxruntime {
namespace EinsumOp {
// Bundle of the ROCM resources needed by the ROCM ops executed as part of the Einsum flow:
// a rocBLAS handle and the ROCM execution provider instance.
// The struct only stores the two values it is given; it defines no destructor and releases nothing.
struct EinsumRocmAssets {
  explicit EinsumRocmAssets(rocblas_handle rocblas_handle,
                            ROCMExecutionProvider* rocm_ep)
      : rocblas_handle_(rocblas_handle), rocm_ep_(rocm_ep) {
  }

  rocblas_handle rocblas_handle_;
  ROCMExecutionProvider* rocm_ep_;
};
namespace DeviceHelpers {
// These are ROCM EP specific device helper implementations
namespace RocmDeviceHelpers {
// Every helper below takes `einsum_rocm_assets` as an opaque pointer; callers are expected to pass
// the address of an EinsumRocmAssets instance.

// Transposes `input` into `output` following `permutation`; `input_shape_override` may be supplied
// to use a shape other than the one stored in `input`.
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets);
// Copies the contents of `input` into `output`.
Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets);
// Batched matrix multiplication over raw buffers: `num_batches` products with the sizes M, K, N and
// one stride per buffer.
template <typename T>
Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data,
size_t left_stride, size_t right_stride, size_t output_stride,
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
void* einsum_rocm_assets);
// Sums `input` over `reduce_axes` and returns the result as a new tensor.
template <typename T>
std::unique_ptr<Tensor> ReduceSum(const Tensor& input, gsl::span<const int64_t> reduce_axes,
bool keep_dims, AllocatorPtr allocator,
const TensorShape* input_shape_override,
concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets);
// Returns a new tensor holding the diagonal of `input` along the axes `dim_1` and `dim_2`.
std::unique_ptr<Tensor> Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets);
} // namespace RocmDeviceHelpers
} // namespace DeviceHelpers
} // namespace EinsumOp
} // namespace onnxruntime

View file

@ -0,0 +1,95 @@
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/cu_inc/common.cuh"
#include "einsum_auxiliary_ops_diagonal.h"
namespace onnxruntime {
namespace rocm {
// Device kernel that gathers the diagonal of `input_data` along the axes `dim_1` and `dim_2`.
// Each invocation writes exactly one element, output_data[output_idx], after translating the flat
// output index into a flat input index:
//   - `output_strides` is used to split the output index into one coordinate per output axis,
//   - `input_strides` turns those coordinates back into an input offset.
// The axis walk below skips `dim_2` after incrementing the axis counter, so it relies on
// dim_1 < dim_2; the Diagonal helper that calls DiagonalImpl passes the lesser axis as `dim_1` and
// the greater one as `dim_2`.
// T is only used to move one element of the right width from the input to the output.
template <typename T>
__global__ void _DiagonalKernel(
const T* input_data,
const int64_t input_rank,
const int64_t dim_1,
const int64_t dim_2,
const TArray<int64_t> input_strides,
T* output_data,
const TArray<fast_divmod> output_strides,
const size_t output_size) {
// NOTE(review): this macro comes from common.cuh (not shown here); it is expected to define
// `output_idx` for the current thread and to return when it is not below `output_size` - confirm.
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(output_idx, output_size);
// `dim` receives the coordinate along the output axis being processed,
// `remain` is the part of the flat output index that is still to be decomposed
int dim = 0;
int remain = output_idx;
size_t input_idx = 0;
// Input axis matching the output axis being processed (dim_2 has no output axis of its own)
int64_t current_input_axis = 0;
// Output's rank is always 1 less than the input's rank
for (int i = 0; i < input_rank - 1; ++i) {
// NOTE(review): presumably stores the quotient in `dim` and the remainder in `remain` - confirm
// against the fast_divmod definition.
output_strides[i].divmod(remain, dim, remain);
if (i == dim_1) {
// Process dim_2 as dim_2 needs to have the same dim value as dim_1
// For example: given a tensor of shape [2, 3, 3] and parsing the diagonal along axes `1` and `2`
// we need to parse elements in input[j, i, i] (j -> 0 to 1; and i -> 0 to 2)
// and place them in output[j, i] and by definition of diagonal parsing dim_1 has to be equal to
// dim_2
input_idx += input_strides[dim_2] * dim;
}
input_idx += input_strides[current_input_axis] * dim;
// Update current_input_axis
// If it is dim_2, skip it
if (++current_input_axis == dim_2) {
++current_input_axis;
}
}
output_data[output_idx] = input_data[input_idx];
}
// Host-side launcher for _DiagonalKernel.
// Returns without launching anything when `output_size` is 0. Otherwise the kernel instantiation is
// selected by `element_size`: sizeof(int32_t) -> int32_t, sizeof(int64_t) -> int64_t and
// sizeof(int16_t) -> half, and the kernel is launched on `stream` with
// GridDim::maxThreadsPerBlock threads per block and enough blocks to cover `output_size` elements.
// Any other element size throws via ORT_THROW.
void DiagonalImpl(
    hipStream_t stream,
    const void* input_data,
    const int64_t input_rank,
    const int64_t dim_1,
    const int64_t dim_2,
    const TArray<int64_t> input_strides,
    void* output_data,
    const TArray<fast_divmod> output_strides,
    const size_t output_size,
    size_t element_size) {
  // Nothing to gather for an empty output
  if (output_size == 0) {
    return;
  }

  // Ceiling division: number of blocks needed so that every output element gets a thread
  const int num_blocks = static_cast<int>((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);

  switch (element_size) {
    case sizeof(int32_t): {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel<int32_t>), num_blocks, GridDim::maxThreadsPerBlock, 0, stream,
                         reinterpret_cast<const ToHipType<int32_t>::MappedType*>(input_data), input_rank, dim_1, dim_2,
                         input_strides, reinterpret_cast<ToHipType<int32_t>::MappedType*>(output_data), output_strides,
                         output_size);
      break;
    }

    case sizeof(int64_t): {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel<int64_t>), num_blocks, GridDim::maxThreadsPerBlock, 0, stream,
                         reinterpret_cast<const ToHipType<int64_t>::MappedType*>(input_data), input_rank, dim_1, dim_2,
                         input_strides, reinterpret_cast<ToHipType<int64_t>::MappedType*>(output_data), output_strides,
                         output_size);
      break;
    }

    case sizeof(int16_t): {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel<half>), num_blocks, GridDim::maxThreadsPerBlock, 0, stream,
                         reinterpret_cast<const half*>(input_data), input_rank, dim_1, dim_2,
                         input_strides, reinterpret_cast<half*>(output_data), output_strides,
                         output_size);
      break;
    }

    // Should not hit this as we do not register kernel support for types that will run into this
    default:
      ORT_THROW("Einsum Op: Diagonal parsing unsupported");
  }
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {
// Launches the diagonal-gather kernel on `stream`.
//   input_data / output_data : raw buffers; their element width is given by `element_size` (bytes)
//   input_rank               : rank of the input tensor
//   dim_1, dim_2             : the two axes whose diagonal is extracted
//   input_strides            : strides of the input tensor
//   output_strides           : strides of the output tensor wrapped in fast_divmod
//   output_size              : number of elements in the output
// NOTE(review): the caller visible in this change passes the lesser axis as `dim_1` and the greater
// axis as `dim_2` - confirm whether other callers must do the same.
void DiagonalImpl(
hipStream_t stream,
const void* input_data,
const int64_t input_rank,
const int64_t dim_1,
const int64_t dim_2,
const TArray<int64_t> input_strides,
void* output_data,
const TArray<fast_divmod> output_strides,
const size_t output_size,
size_t element_size);
} // namespace rocm
} // namespace onnxruntime

View file

@ -1773,7 +1773,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow)>,

View file

@ -352,6 +352,8 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose) {
test.Run();
}
// ROCm doesn't support double
#ifndef USE_ROCM
TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "iji->ji");
@ -359,6 +361,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) {
test.AddOutput<double>("o", {2, 2}, {1., 2., 3., 4.});
test.Run();
}
#endif
TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int32) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);

View file

@ -235,6 +235,7 @@ def hipify(src_file_path, dst_file_path):
s = s.replace("hipblasDestroy", "rocblas_destroy_handle")
s = s.replace("hipblasSetStream", "rocblas_set_stream")
s = s.replace("HIPBLAS_OP_T", "rocblas_operation_transpose")
s = s.replace("HIPBLAS_OP_N", "rocblas_operation_none")
s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels")
s = s.replace("cudaEvent", "hipEvent")