diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index 033217df99..117e562068 100644 --- a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -24,8 +24,6 @@ Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets) { ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(), "Einsum op: The candidate output does not match the actual output's shape"); // There are no string tensors in Einsum's case - so safely use memcpy - // TODO: Currently, triggers copy on stream 0, investigate if we can still do that - // *if* the kernel is launched in a different stream CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(), cudaMemcpyDeviceToDevice, static_cast(static_cast(einsum_cuda_assets)->cuda_ep_->GetComputeStream()))); diff --git a/onnxruntime/core/providers/rocm/math/einsum.cc b/onnxruntime/core/providers/rocm/math/einsum.cc new file mode 100644 index 0000000000..00e56b89c4 --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/einsum.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "einsum.h" + +namespace onnxruntime { + +// This function must exist due to the C++ base class constructor needing this to be defined for the vtable, but it is never called. +Status Einsum::DeviceCompute(OpKernelContext* /*context*/, const std::vector& /*inputs*/, + AllocatorPtr /*allocator*/, concurrency::ThreadPool* /*tp*/) const { + assert(false); + return Status::OK(); +} + +namespace rocm { + +ONNX_OPERATOR_KERNEL_EX( + Einsum, + kOnnxDomain, + 12, + kRocmExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), + Einsum); + +Status Einsum::Compute(OpKernelContext* context) const { + return onnxruntime::Einsum::Compute(context); +} + +Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector& inputs, + AllocatorPtr allocator, concurrency::ThreadPool* tp) const { + rocblas_handle rocblas_handle = rocm_ep_->PerThreadRocblasHandle(); + + EinsumOp::EinsumRocmAssets einsum_rocm_assets(rocblas_handle, rocm_ep_); + + // EinsumComputePreprocessor section - + auto einsum_compute_preprocessor = EinsumComputePreprocessor::Create(*einsum_equation_preprocessor_, inputs, allocator, + &einsum_rocm_assets); + + einsum_compute_preprocessor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Diagonal, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose); + // Compute all required metadata to be used at Einsum compute time and return error status code if one was generated + ORT_RETURN_IF_ERROR(einsum_compute_preprocessor->Run()); + + // EinsumComputeProcessor section - + if (inputs[0]->IsDataType()) { + auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, + *einsum_compute_preprocessor, + &einsum_rocm_assets); + + einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::MatMul, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::ReduceSum, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::DataCopy); + return einsum_compute_processor->Run(); + } else if (inputs[0]->IsDataType()) { + auto einsum_compute_processor = EinsumTypedComputeProcessor::Create(context, allocator, tp, + *einsum_compute_preprocessor, + &einsum_rocm_assets); + + einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::RocmDeviceHelpers::Transpose, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::MatMul, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::ReduceSum, + EinsumOp::DeviceHelpers::RocmDeviceHelpers::DataCopy); + return einsum_compute_processor->Run(); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Einsum op: An implementation for the input type ", + inputs[0]->DataType(), " is not supported yet"); +} + +} // namespace rocm + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum.h b/onnxruntime/core/providers/rocm/math/einsum.h new file mode 100644 index 0000000000..a4adc3da98 --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/einsum.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/platform/threadpool.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/cpu/math/einsum.h" +#include "einsum_utils/einsum_auxiliary_ops.h" +#include "core/providers/rocm/rocm_execution_provider.h" + +namespace onnxruntime { +namespace rocm { + +class Einsum final : public onnxruntime::Einsum { + public: + Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) { + // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method + // TODO: Clean up the ROCMExecutionProvider interface to avoid this + rocm_ep_ = const_cast( + static_cast(info.GetExecutionProvider())); + } + + Status Compute(OpKernelContext* context) const override; + + private: + Status DeviceCompute(OpKernelContext* context, const std::vector& inputs, + AllocatorPtr allocator, concurrency::ThreadPool* tp) const override; + + // Members of Einsum ROCM kernel + using onnxruntime::Einsum::einsum_equation_preprocessor_; + using onnxruntime::Einsum::equation_; + + // We need to access to the ROCM EP instance to get the rocblas/miopen handles + ROCMExecutionProvider* rocm_ep_; +}; + +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc new file mode 100644 index 0000000000..3922617e38 --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.cc @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace concurrency { +class ThreadPool; +} +} // namespace onnxruntime + +#include "einsum_auxiliary_ops.h" + +namespace onnxruntime { + +namespace EinsumOp { + +namespace DeviceHelpers { + +namespace RocmDeviceHelpers { + +// ROCM EP specific Data copy helper +Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets) { + ORT_ENFORCE(output.SizeInBytes() == input.SizeInBytes(), + "Einsum op: The candidate output does not match the actual output's shape"); + // There are no string tensors in Einsum's case - so safely use memcpy + HIP_RETURN_IF_ERROR(hipMemcpyAsync(output.MutableDataRaw(), input.DataRaw(), input.SizeInBytes(), + hipMemcpyDeviceToDevice, + static_cast(static_cast(einsum_rocm_assets)->rocm_ep_->GetComputeStream()))); + + return Status::OK(); +} + +// ROCM EP specific Transpose helper +Status Transpose(const gsl::span& permutation, const Tensor& input, + Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets) { + return rocm::Transpose::DoTranspose(static_cast(einsum_rocm_assets)->rocm_ep_->GetDeviceProp(), + static_cast(static_cast(einsum_rocm_assets)->rocm_ep_->GetComputeStream()), + static_cast(einsum_rocm_assets)->rocblas_handle_, + permutation, input, output, input_shape_override); +} + +// ROCM EP specific MatMul helper +template +Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, + size_t left_stride, size_t right_stride, size_t output_stride, + size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* /*tp*/, + void* einsum_rocm_assets) { + typedef typename rocm::ToHipType::MappedType HipT; + + HipT one = rocm::ToHipType::FromFloat(1.0f); + HipT zero = rocm::ToHipType::FromFloat(0.0f); + + ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(static_cast(einsum_rocm_assets)->rocblas_handle_, + rocblas_operation_none, + rocblas_operation_none, + static_cast(N), + static_cast(M), + static_cast(K), + &one, + reinterpret_cast(input_2_data), + static_cast(N), + static_cast(right_stride), + reinterpret_cast(input_1_data), + static_cast(K), + static_cast(left_stride), + &zero, + reinterpret_cast(output_data), + static_cast(N), + static_cast(output_stride), + static_cast(num_batches))); + + return Status::OK(); +} + +// ROCM EP specific ReduceSum helper +template +std::unique_ptr ReduceSum(const Tensor& input, gsl::span reduce_axes, + bool keep_dims, AllocatorPtr allocator, + const TensorShape* input_shape_override, + concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets) { + return rocm::ReductionOps::ReduceCompute(*static_cast(einsum_rocm_assets)->rocm_ep_, MIOPEN_REDUCE_TENSOR_ADD, + allocator, input, reduce_axes, + keep_dims, false, false, false, + true, input_shape_override); +} + +// ROCM EP specific Diagonal helper +std::unique_ptr Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets) { + const auto& input_shape = input.Shape(); + const auto& input_dims = input_shape.GetDims(); + auto rank = static_cast(input_dims.size()); + + ORT_ENFORCE(rank >= 2 && dim_1 != dim_2 && input_dims[dim_1] == input_dims[dim_2], + "Cannot parse the diagonal elements along dims ", dim_1, " and ", dim_2, " for input shape ", input_shape); + + int64_t first_dim = -1; // first_dim holds the lesser of dim_1 and dim_2 + int64_t second_dim = -1; // second_dim holds the greater of dim_1 and dim_2 + if (dim_1 < dim_2) { + first_dim = dim_1; + second_dim = dim_2; + } else { + first_dim = dim_2; + second_dim = dim_1; + } + + // Make a copy - we are going to mutate the dims + TensorShapeVector output_dims = input_shape.AsShapeVector(); + + // Remove the dim value in `second_dim` - + // The diagonal values are stored along `first_dim` + output_dims.erase(output_dims.begin() + second_dim); + + auto output = Tensor::Create(input.DataType(), output_dims, allocator); + + TensorPitches input_strides(input.Shape().GetDims()); + rocm::TArray gpu_input_strides(input_strides); + + auto output_rank = static_cast(output_dims.size()); + rocm::TArray gpu_output_strides(output_rank); + TensorPitches output_strides(output_dims); + for (auto i = 0; i < output_rank; i++) { + gpu_output_strides[i] = rocm::fast_divmod(static_cast(output_strides[i])); + } + + DiagonalImpl( + static_cast(static_cast(einsum_rocm_assets)->rocm_ep_->GetComputeStream()), + input.DataRaw(), + input.Shape().GetDims().size(), + first_dim, + second_dim, + gpu_input_strides, + output->MutableDataRaw(), + gpu_output_strides, + TensorShape(output_dims).Size(), + input.DataType()->Size()); + + return output; +} + +} // namespace RocmDeviceHelpers + +} // namespace DeviceHelpers + +// Explicit template instantiations of functions + +// float +template Status DeviceHelpers::RocmDeviceHelpers::MatMul( + const float* input_1_data, const float* input_2_data, float* output_data, + size_t left_stride, size_t right_stride, size_t output_stride, + size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp, + void* einsum_rocm_assets); + +template std::unique_ptr DeviceHelpers::RocmDeviceHelpers::ReduceSum( + const Tensor& input, gsl::span reduce_axes, + bool keep_dims, AllocatorPtr allocator, + const TensorShape* input_shape_override, + concurrency::ThreadPool* tp, void* einsum_rocm_assets); + +// MLFloat16 +template Status DeviceHelpers::RocmDeviceHelpers::MatMul( + const MLFloat16* input_1_data, const MLFloat16* input_2_data, MLFloat16* output_data, + size_t left_stride, size_t right_stride, size_t output_stride, + size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp, + void* einsum_rocm_assets); + +template std::unique_ptr DeviceHelpers::RocmDeviceHelpers::ReduceSum( + const Tensor& input, gsl::span reduce_axes, + bool keep_dims, AllocatorPtr allocator, + const TensorShape* input_shape_override, + concurrency::ThreadPool* tp, void* einsum_rocm_assets); + +} // namespace EinsumOp + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h new file mode 100644 index 0000000000..0415e6ac46 --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This module hosts implementations and thin wrappers over other onnx operator implementations +// that will be called from within the Einsum operator implementation + +#pragma once + +#include "core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.h" +#include "core/providers/rocm/tensor/transpose.h" +#include "core/providers/rocm/reduction/reduction_ops.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/cpu/tensor/utils.h" +#include "einsum_auxiliary_ops_diagonal.h" +#include "core/providers/rocm/rocm_common.h" + +namespace onnxruntime { + +namespace EinsumOp { + +// Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow +struct EinsumRocmAssets { + explicit EinsumRocmAssets(rocblas_handle rocblas_handle, + ROCMExecutionProvider* rocm_ep) { + rocblas_handle_ = rocblas_handle; + rocm_ep_ = rocm_ep; + } + + rocblas_handle rocblas_handle_; + ROCMExecutionProvider* rocm_ep_; +}; + +namespace DeviceHelpers { + +// These are ROCM EP specific device helper implementations +namespace RocmDeviceHelpers { + +Status Transpose(const gsl::span& permutation, const Tensor& input, + Tensor& output, const TensorShape* input_shape_override, void* einsum_rocm_assets); + +Status DataCopy(const Tensor& input, Tensor& output, void* einsum_rocm_assets); + +template +Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, + size_t left_stride, size_t right_stride, size_t output_stride, + size_t num_batches, size_t M, size_t K, size_t N, concurrency::ThreadPool* tp, + void* einsum_rocm_assets); + +template +std::unique_ptr ReduceSum(const Tensor& input, gsl::span reduce_axes, + bool keep_dims, AllocatorPtr allocator, + const TensorShape* input_shape_override, + concurrency::ThreadPool* /*tp*/, void* einsum_rocm_assets); + +std::unique_ptr Diagonal(const Tensor& input, int64_t dim_1, int64_t dim_2, AllocatorPtr allocator, void* einsum_rocm_assets); + +} // namespace RocmDeviceHelpers + +} // namespace DeviceHelpers + +} // namespace EinsumOp + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu new file mode 100644 index 0000000000..7b8630568a --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu @@ -0,0 +1,95 @@ +#include "hip/hip_runtime.h" +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/rocm/cu_inc/common.cuh" +#include "einsum_auxiliary_ops_diagonal.h" + +namespace onnxruntime { + +namespace rocm { + +template +__global__ void _DiagonalKernel( + const T* input_data, + const int64_t input_rank, + const int64_t dim_1, + const int64_t dim_2, + const TArray input_strides, + T* output_data, + const TArray output_strides, + const size_t output_size) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(output_idx, output_size); + int dim = 0; + int remain = output_idx; + size_t input_idx = 0; + int64_t current_input_axis = 0; + + // Output's rank is always 1 less than the input's rank + for (int i = 0; i < input_rank - 1; ++i) { + output_strides[i].divmod(remain, dim, remain); + if (i == dim_1) { + // Process dim_2 as dim_2 needs to have the same dim value as dim_1 + // For example: given a tensor of shape [2, 3, 3] and parsing the diagonal along axes `1` and `2` + // we need to parse elements in input[j, i, i] (j -> 0 to 1; and i -> 0 to 2) + // and place them in output[j, i] and by definition of diagonal parsing dim_1 has to be equal to + // dim_2 + input_idx += input_strides[dim_2] * dim; + } + input_idx += input_strides[current_input_axis] * dim; + + // Update current_input_axis + // If it is dim_2, skip it + if (++current_input_axis == dim_2) { + ++current_input_axis; + } + } + output_data[output_idx] = input_data[input_idx]; +} + +void DiagonalImpl( + hipStream_t stream, + const void* input_data, + const int64_t input_rank, + const int64_t dim_1, + const int64_t dim_2, + const TArray input_strides, + void* output_data, + const TArray output_strides, + const size_t output_size, + size_t element_size) { + if (output_size > 0) { + int blocksPerGrid = static_cast((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock); + + switch (element_size) { + case sizeof(int32_t): + hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, + reinterpret_cast::MappedType*>(input_data), input_rank, dim_1, dim_2, + input_strides, reinterpret_cast::MappedType*>(output_data), output_strides, + output_size); + break; + + case sizeof(int64_t): + hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, + reinterpret_cast::MappedType*>(input_data), input_rank, dim_1, dim_2, + input_strides, reinterpret_cast::MappedType*>(output_data), output_strides, + output_size); + break; + + case sizeof(int16_t): + hipLaunchKernelGGL(HIP_KERNEL_NAME(_DiagonalKernel), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream, + reinterpret_cast(input_data), input_rank, dim_1, dim_2, + input_strides, reinterpret_cast(output_data), output_strides, + output_size); + break; + + // Should not hit this as we do not register kernel support for types that will run into this + default: + ORT_THROW("Einsum Op: Diagonal parsing unsupported"); + } + } +} + +} // namespace rocm + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.h new file mode 100644 index 0000000000..4742b5338e --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/shared_inc/rocm_utils.h" + +namespace onnxruntime { + +namespace rocm { + +void DiagonalImpl( + hipStream_t stream, + const void* input_data, + const int64_t input_rank, + const int64_t dim_1, + const int64_t dim_2, + const TArray input_strides, + void* output_data, + const TArray output_strides, + const size_t output_size, + size_t element_size); + +} // namespace rocm + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index a3cda47a98..0b7b608d1a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1773,7 +1773,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 13 BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 9d1bf63b60..ef5c9bb8a9 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -352,6 +352,8 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose) { test.Run(); } +// ROCm doesn't support double +#ifndef USE_ROCM TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); test.AddAttribute("equation", "iji->ji"); @@ -359,6 +361,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) { test.AddOutput("o", {2, 2}, {1., 2., 3., 4.}); test.Run(); } +#endif TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int32) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 09660e029f..7d1d7cdda6 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -235,6 +235,7 @@ def hipify(src_file_path, dst_file_path): s = s.replace("hipblasDestroy", "rocblas_destroy_handle") s = s.replace("hipblasSetStream", "rocblas_set_stream") s = s.replace("HIPBLAS_OP_T", "rocblas_operation_transpose") + s = s.replace("HIPBLAS_OP_N", "rocblas_operation_none") s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels") s = s.replace("cudaEvent", "hipEvent")