mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[ROCm] Enable Einsum for inferencing perf (#12360)
* enable einsum * address comments * comments Co-authored-by: Ethan Tao <ettao@microsoft.com>
This commit is contained in:
parent
9c0fa65110
commit
e4bd41fb3b
10 changed files with 477 additions and 3 deletions
|
|
@ -24,8 +24,6 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) {
|
|||
ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(),
|
||||
"Einsum op: The candidate output does not match the actual output's shape");
|
||||
// There are no string tensors in Einsum's case - so safely use memcpy
|
||||
// TODO: Currently, triggers copy on stream 0, investigate if we can still do that
|
||||
// *if* the kernel is launched in a different stream
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
static_cast<cudaStream_t>(static_cast<EinsumCudaAssets*>(einsum_cuda_assets)->cuda_ep_->GetComputeStream())));
|
||||
|
|
|
|||
74
onnxruntime/core/providers/rocm/math/einsum.cc
Normal file
74
onnxruntime/core/providers/rocm/math/einsum.cc
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "einsum.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// This function must exist due to the C++ base class constructor needing this to be defined for the vtable, but it is never called.
|
||||
Status Einsum::DeviceCompute(OpKernelContext* /*context*/, const std::vector<const Tensor*>& /*inputs*/,
|
||||
AllocatorPtr /*allocator*/, concurrency::ThreadPool* /*tp*/) const {
|
||||
assert(false);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace rocm {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Einsum,
|
||||
kOnnxDomain,
|
||||
12,
|
||||
kRocmExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", std::vector<MLDataType>{DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<MLFloat16>()}),
|
||||
Einsum);
|
||||
|
||||
Status Einsum::Compute(OpKernelContext* context) const {
|
||||
return onnxruntime::Einsum::Compute(context);
|
||||
}
|
||||
|
||||
Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const Tensor*>& inputs,
|
||||
AllocatorPtr allocator, concurrency::ThreadPool* tp) const {
|
||||
rocblas_handle rocblas_handle = rocm_ep_->PerThreadRocblasHandle();
|
||||
|
||||
EinsumOp::EinsumRocmAssets einsum_rocm_assets(rocblas_handle, rocm_ep_);
|
||||
|
||||
// EinsumComputePreprocessor section -
|
||||
auto einsum_compute_preprocessor = EinsumComputePreprocessor::Create(*einsum_equation_preprocessor_, inputs, allocator,
|
||||
&einsum_rocm_assets);
|
||||
|
||||
einsum_compute_preprocessor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Diagonal,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose);
|
||||
// Compute all required metadata to be used at Einsum compute time and return error status code if one was generated
|
||||
ORT_RETURN_IF_ERROR(einsum_compute_preprocessor->Run());
|
||||
|
||||
// EinsumComputeProcessor section -
|
||||
if (inputs[0]->IsDataType<float>()) {
|
||||
auto einsum_compute_processor = EinsumTypedComputeProcessor<float>::Create(context, allocator, tp,
|
||||
*einsum_compute_preprocessor,
|
||||
&einsum_rocm_assets);
|
||||
|
||||
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::MatMul<float>,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::ReduceSum<float>,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::DataCopy);
|
||||
return einsum_compute_processor->Run();
|
||||
} else if (inputs[0]->IsDataType<MLFloat16>()) {
|
||||
auto einsum_compute_processor = EinsumTypedComputeProcessor<MLFloat16>::Create(context, allocator, tp,
|
||||
*einsum_compute_preprocessor,
|
||||
&einsum_rocm_assets);
|
||||
|
||||
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::MatMul<MLFloat16>,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::ReduceSum<MLFloat16>,
|
||||
EinsumOp::DeviceHelpers::RocmDeviceHelpers::DataCopy);
|
||||
return einsum_compute_processor->Run();
|
||||
}
|
||||
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"Einsum op: An implementation for the input type ",
|
||||
inputs[0]->DataType(), " is not supported yet");
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
} // namespace onnxruntime
|
||||
39
onnxruntime/core/providers/rocm/math/einsum.h
Normal file
39
onnxruntime/core/providers/rocm/math/einsum.h
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/cpu/math/einsum.h"
|
||||
#include "einsum_utils/einsum_auxiliary_ops.h"
|
||||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
class Einsum final : public onnxruntime::Einsum {
|
||||
public:
|
||||
Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) {
|
||||
// We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method
|
||||
// TODO: Clean up the ROCMExecutionProvider interface to avoid this
|
||||
rocm_ep_ = const_cast<ROCMExecutionProvider*>(
|
||||
static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()));
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
Status DeviceCompute(OpKernelContext* context, const std::vector<const Tensor*>& inputs,
|
||||
AllocatorPtr allocator, concurrency::ThreadPool* tp) const override;
|
||||
|
||||
// Members of Einsum ROCM kernel
|
||||
using onnxruntime::Einsum::einsum_equation_preprocessor_;
|
||||
using onnxruntime::Einsum::equation_;
|
||||
|
||||
// We need to access to the ROCM EP instance to get the rocblas/miopen handles
|
||||
ROCMExecutionProvider* rocm_ep_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
class ThreadPool;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
#include "einsum_auxiliary_ops.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace EinsumOp {
|
||||
|
||||
namespace DeviceHelpers {
|
||||
|
||||
namespace RocmDeviceHelpers {
|
||||
|
||||
// ROCM EP specific Data copy helper
|
||||
Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets) {
|
||||
ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(),
|
||||
"Einsum op: The candidate output does not match the actual output's shape");
|
||||
// There are no string tensors in Einsum's case - so safely use memcpy
|
||||
HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(),
|
||||
hipMemcpyDeviceToDevice,
|
||||
static_cast<hipStream_t>(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_->GetComputeStream())));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// ROCM EP specific Transpose helper
|
||||
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
|
||||
Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets) {
|
||||
return rocm::Transpose::DoTranspose(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_->GetDeviceProp(),
|
||||
static_cast<hipStream_t>(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_->GetComputeStream()),
|
||||
static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocblas_handle_,
|
||||
permutation, input, output, input_shape_override);
|
||||
}
|
||||
|
||||
// ROCM EP specific MatMul helper
|
||||
template <typename T>
|
||||
Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data,
|
||||
size_t left_stride, size_t right_stride, size_t output_stride,
|
||||
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* /*tp*/,
|
||||
void* einsum_rocm_assets) {
|
||||
typedef typename rocm::ToHipType<T>::MappedType HipT;
|
||||
|
||||
HipT one = rocm::ToHipType<T>::FromFloat(1.0f);
|
||||
HipT zero = rocm::ToHipType<T>::FromFloat(0.0f);
|
||||
|
||||
ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocblas_handle_,
|
||||
rocblas_operation_none,
|
||||
rocblas_operation_none,
|
||||
static_cast<int>(N),
|
||||
static_cast<int>(M),
|
||||
static_cast<int>(K),
|
||||
&one,
|
||||
reinterpret_cast<const HipT*>(input_2_data),
|
||||
static_cast<int>(N),
|
||||
static_cast<int>(right_stride),
|
||||
reinterpret_cast<const HipT*>(input_1_data),
|
||||
static_cast<int>(K),
|
||||
static_cast<int>(left_stride),
|
||||
&zero,
|
||||
reinterpret_cast<HipT*>(output_data),
|
||||
static_cast<int>(N),
|
||||
static_cast<int>(output_stride),
|
||||
static_cast<int>(num_batches)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// ROCM EP specific ReduceSum helper
|
||||
template <typename T>
|
||||
std::unique_ptr<Tensor> ReduceSum(const Tensor& input, gsl::span<const int64_t> reduce_axes,
|
||||
bool keep_dims, AllocatorPtr allocator,
|
||||
const TensorShape* input_shape_override,
|
||||
concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets) {
|
||||
return rocm::ReductionOps::ReduceCompute<T>(*static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_, MIOPEN_REDUCE_TENSOR_ADD,
|
||||
allocator, input, reduce_axes,
|
||||
keep_dims, false, false, false,
|
||||
true, input_shape_override);
|
||||
}
|
||||
|
||||
// ROCM EP specific Diagonal helper
|
||||
std::unique_ptr<Tensor> Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets) {
|
||||
const auto& input_shape = input.Shape();
|
||||
const auto& input_dims = input_shape.GetDims();
|
||||
auto rank = static_cast<int64_t>(input_dims.size());
|
||||
|
||||
ORT_ENFORCE(rank >= 2 && dim_1 != dim_2 && input_dims[dim_1] == input_dims[dim_2],
|
||||
"Cannot parse the diagonal elements along dims ", dim_1, " and ", dim_2, " for input shape ", input_shape);
|
||||
|
||||
int64_t first_dim = -1; // first_dim holds the lesser of dim_1 and dim_2
|
||||
int64_t second_dim = -1; // second_dim holds the greater of dim_1 and dim_2
|
||||
if (dim_1 < dim_2) {
|
||||
first_dim = dim_1;
|
||||
second_dim = dim_2;
|
||||
} else {
|
||||
first_dim = dim_2;
|
||||
second_dim = dim_1;
|
||||
}
|
||||
|
||||
// Make a copy - we are going to mutate the dims
|
||||
TensorShapeVector output_dims = input_shape.AsShapeVector();
|
||||
|
||||
// Remove the dim value in `second_dim` -
|
||||
// The diagonal values are stored along `first_dim`
|
||||
output_dims.erase(output_dims.begin() + second_dim);
|
||||
|
||||
auto output = Tensor::Create(input.DataType(), output_dims, allocator);
|
||||
|
||||
TensorPitches input_strides(input.Shape().GetDims());
|
||||
rocm::TArray<int64_t> gpu_input_strides(input_strides);
|
||||
|
||||
auto output_rank = static_cast<int32_t>(output_dims.size());
|
||||
rocm::TArray<rocm::fast_divmod> gpu_output_strides(output_rank);
|
||||
TensorPitches output_strides(output_dims);
|
||||
for (auto i = 0; i < output_rank; i++) {
|
||||
gpu_output_strides[i] = rocm::fast_divmod(static_cast<int>(output_strides[i]));
|
||||
}
|
||||
|
||||
DiagonalImpl(
|
||||
static_cast<hipStream_t>(static_cast<EinsumRocmAssets*>(einsum_rocm_assets)->rocm_ep_->GetComputeStream()),
|
||||
input.DataRaw(),
|
||||
input.Shape().GetDims().size(),
|
||||
first_dim,
|
||||
second_dim,
|
||||
gpu_input_strides,
|
||||
output->MutableDataRaw(),
|
||||
gpu_output_strides,
|
||||
TensorShape(output_dims).Size(),
|
||||
input.DataType()->Size());
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace RocmDeviceHelpers
|
||||
|
||||
} // namespace DeviceHelpers
|
||||
|
||||
// Explicit template instantiations of functions
|
||||
|
||||
// float
|
||||
template Status DeviceHelpers::RocmDeviceHelpers::MatMul<float>(
|
||||
const float* input_1_data, const float* input_2_data, float* output_data,
|
||||
size_t left_stride, size_t right_stride, size_t output_stride,
|
||||
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
|
||||
void* einsum_rocm_assets);
|
||||
|
||||
template std::unique_ptr<Tensor> DeviceHelpers::RocmDeviceHelpers::ReduceSum<float>(
|
||||
const Tensor& input, gsl::span<const int64_t> reduce_axes,
|
||||
bool keep_dims, AllocatorPtr allocator,
|
||||
const TensorShape* input_shape_override,
|
||||
concurrency::ThreadPool* tp, void* einsum_rocm_assets);
|
||||
|
||||
// MLFloat16
|
||||
template Status DeviceHelpers::RocmDeviceHelpers::MatMul<MLFloat16>(
|
||||
const MLFloat16* input_1_data, const MLFloat16* input_2_data, MLFloat16* output_data,
|
||||
size_t left_stride, size_t right_stride, size_t output_stride,
|
||||
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
|
||||
void* einsum_rocm_assets);
|
||||
|
||||
template std::unique_ptr<Tensor> DeviceHelpers::RocmDeviceHelpers::ReduceSum<MLFloat16>(
|
||||
const Tensor& input, gsl::span<const int64_t> reduce_axes,
|
||||
bool keep_dims, AllocatorPtr allocator,
|
||||
const TensorShape* input_shape_override,
|
||||
concurrency::ThreadPool* tp, void* einsum_rocm_assets);
|
||||
|
||||
} // namespace EinsumOp
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This module hosts implementations and thin wrappers over other onnx operator implementations
|
||||
// that will be called from within the Einsum operator implementation
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h"
|
||||
#include "core/providers/rocm/tensor/transpose.h"
|
||||
#include "core/providers/rocm/reduction/reduction_ops.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "einsum_auxiliary_ops_diagonal.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace EinsumOp {
|
||||
|
||||
// Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow
|
||||
struct EinsumRocmAssets {
|
||||
explicit EinsumRocmAssets(rocblas_handle rocblas_handle,
|
||||
ROCMExecutionProvider* rocm_ep) {
|
||||
rocblas_handle_ = rocblas_handle;
|
||||
rocm_ep_ = rocm_ep;
|
||||
}
|
||||
|
||||
rocblas_handle rocblas_handle_;
|
||||
ROCMExecutionProvider* rocm_ep_;
|
||||
};
|
||||
|
||||
namespace DeviceHelpers {
|
||||
|
||||
// These are ROCM EP specific device helper implementations
|
||||
namespace RocmDeviceHelpers {
|
||||
|
||||
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
|
||||
Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets);
|
||||
|
||||
Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets);
|
||||
|
||||
template <typename T>
|
||||
Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data,
|
||||
size_t left_stride, size_t right_stride, size_t output_stride,
|
||||
size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp,
|
||||
void* einsum_rocm_assets);
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<Tensor> ReduceSum(const Tensor& input, gsl::span<const int64_t> reduce_axes,
|
||||
bool keep_dims, AllocatorPtr allocator,
|
||||
const TensorShape* input_shape_override,
|
||||
concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets);
|
||||
|
||||
std::unique_ptr<Tensor> Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets);
|
||||
|
||||
} // namespace RocmDeviceHelpers
|
||||
|
||||
} // namespace DeviceHelpers
|
||||
|
||||
} // namespace EinsumOp
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
#include "hip/hip_runtime.h"
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "einsum_auxiliary_ops_diagonal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
__global__ void _DiagonalKernel(
|
||||
const T* input_data,
|
||||
const int64_t input_rank,
|
||||
const int64_t dim_1,
|
||||
const int64_t dim_2,
|
||||
const TArray<int64_t> input_strides,
|
||||
T* output_data,
|
||||
const TArray<fast_divmod> output_strides,
|
||||
const size_t output_size) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(output_idx, output_size);
|
||||
int dim = 0;
|
||||
int remain = output_idx;
|
||||
size_t input_idx = 0;
|
||||
int64_t current_input_axis = 0;
|
||||
|
||||
// Output's rank is always 1 less than the input's rank
|
||||
for (int i = 0; i < input_rank - 1; ++i) {
|
||||
output_strides[i].divmod(remain, dim, remain);
|
||||
if (i == dim_1) {
|
||||
// Process dim_2 as dim_2 needs to have the same dim value as dim_1
|
||||
// For example: given a tensor of shape [2, 3, 3] and parsing the diagonal along axes `1` and `2`
|
||||
// we need to parse elements in input[j, i, i] (j -> 0 to 1; and i -> 0 to 2)
|
||||
// and place them in output[j, i] and by definition of diagonal parsing dim_1 has to be equal to
|
||||
// dim_2
|
||||
input_idx += input_strides[dim_2] * dim;
|
||||
}
|
||||
input_idx += input_strides[current_input_axis] * dim;
|
||||
|
||||
// Update current_input_axis
|
||||
// If it is dim_2, skip it
|
||||
if (++current_input_axis == dim_2) {
|
||||
++current_input_axis;
|
||||
}
|
||||
}
|
||||
output_data[output_idx] = input_data[input_idx];
|
||||
}
|
||||
|
||||
void DiagonalImpl(
|
||||
hipStream_t stream,
|
||||
const void* input_data,
|
||||
const int64_t input_rank,
|
||||
const int64_t dim_1,
|
||||
const int64_t dim_2,
|
||||
const TArray<int64_t> input_strides,
|
||||
void* output_data,
|
||||
const TArray<fast_divmod> output_strides,
|
||||
const size_t output_size,
|
||||
size_t element_size) {
|
||||
if (output_size > 0) {
|
||||
int blocksPerGrid = static_cast<int>((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
|
||||
|
||||
switch (element_size) {
|
||||
case sizeof(int32_t):
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel<int32_t>), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
|
||||
reinterpret_cast<const ToHipType<int32_t>::MappedType*>(input_data), input_rank, dim_1, dim_2,
|
||||
input_strides, reinterpret_cast<ToHipType<int32_t>::MappedType*>(output_data), output_strides,
|
||||
output_size);
|
||||
break;
|
||||
|
||||
case sizeof(int64_t):
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel<int64_t>), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
|
||||
reinterpret_cast<const ToHipType<int64_t>::MappedType*>(input_data), input_rank, dim_1, dim_2,
|
||||
input_strides, reinterpret_cast<ToHipType<int64_t>::MappedType*>(output_data), output_strides,
|
||||
output_size);
|
||||
break;
|
||||
|
||||
case sizeof(int16_t):
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel<half>), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
|
||||
reinterpret_cast<const half*>(input_data), input_rank, dim_1, dim_2,
|
||||
input_strides, reinterpret_cast<half*>(output_data), output_strides,
|
||||
output_size);
|
||||
break;
|
||||
|
||||
// Should not hit this as we do not register kernel support for types that will run into this
|
||||
default:
|
||||
ORT_THROW("Einsum Op: Diagonal parsing unsupported");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/shared_inc/rocm_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace rocm {
|
||||
|
||||
void DiagonalImpl(
|
||||
hipStream_t stream,
|
||||
const void* input_data,
|
||||
const int64_t input_rank,
|
||||
const int64_t dim_1,
|
||||
const int64_t dim_2,
|
||||
const TArray<int64_t> input_strides,
|
||||
void* output_data,
|
||||
const TArray<fast_divmod> output_strides,
|
||||
const size_t output_size,
|
||||
size_t element_size);
|
||||
|
||||
} // namespace rocm
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1773,7 +1773,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
|
||||
|
||||
// OpSet 13
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
|
||||
|
|
|
|||
|
|
@ -352,6 +352,8 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// ROCm doesn't support double
|
||||
#ifndef USE_ROCM
|
||||
TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "iji->ji");
|
||||
|
|
@ -359,6 +361,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) {
|
|||
test.AddOutput<double>("o", {2, 2}, {1., 2., 3., 4.});
|
||||
test.Run();
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int32) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
|
|
|
|||
|
|
@ -235,6 +235,7 @@ def hipify(src_file_path, dst_file_path):
|
|||
s = s.replace("hipblasDestroy", "rocblas_destroy_handle")
|
||||
s = s.replace("hipblasSetStream", "rocblas_set_stream")
|
||||
s = s.replace("HIPBLAS_OP_T", "rocblas_operation_transpose")
|
||||
s = s.replace("HIPBLAS_OP_N", "rocblas_operation_none")
|
||||
|
||||
s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels")
|
||||
s = s.replace("cudaEvent", "hipEvent")
|
||||
|
|
|
|||
Loading…
Reference in a new issue