From 4c1db50df52c889bccc538a8dd83afec4e4bb095 Mon Sep 17 00:00:00 2001 From: Jesse Benson Date: Thu, 14 Jan 2021 13:57:18 -0800 Subject: [PATCH] miopen common --- .../core/providers/rocm/miopen_common.cc | 26 +++++++- .../core/providers/rocm/miopen_common.h | 1 + .../providers/rocm/reduction/reduction_ops.cc | 2 + .../providers/rocm/reduction/reduction_ops.h | 63 +++++++++++++------ 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/rocm/miopen_common.cc b/onnxruntime/core/providers/rocm/miopen_common.cc index 8f2054e4ee..6c18b202a7 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.cc +++ b/onnxruntime/core/providers/rocm/miopen_common.cc @@ -41,16 +41,36 @@ Status MiopenTensor::Set(const std::vector& input_dims, miopenDataType_ return Status::OK(); } +Status MiopenTensor::Set(const MiopenTensor& x_desc, miopenBatchNormMode_t mode) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + MIOPEN_RETURN_IF_ERROR(miopenDeriveBNTensorDescriptor(tensor_, x_desc, mode)); + return Status::OK(); +} + template miopenDataType_t MiopenTensor::GetDataType() { - ORT_THROW("miopen engine currently supports only single/half precision data types."); + ORT_THROW("miopen engine currently supports only single/half/int32/int8 precision data types."); +} + +template<> +miopenDataType_t MiopenTensor::GetDataType() { + return miopenFloat; } template <> -miopenDataType_t MiopenTensor::GetDataType() { return miopenFloat; } +miopenDataType_t MiopenTensor::GetDataType() { + return miopenHalf; +} template <> -miopenDataType_t MiopenTensor::GetDataType() { return miopenHalf; } +miopenDataType_t MiopenTensor::GetDataType() { + return miopenInt32; +} + +template <> +miopenDataType_t MiopenTensor::GetDataType() { + return miopenInt8; +} template <> const float Consts::One = 1; diff --git a/onnxruntime/core/providers/rocm/miopen_common.h b/onnxruntime/core/providers/rocm/miopen_common.h index b71f5413e9..73d865dcfd 100644 --- a/onnxruntime/core/providers/rocm/miopen_common.h +++ b/onnxruntime/core/providers/rocm/miopen_common.h @@ -16,6 +16,7 @@ class MiopenTensor final { public: MiopenTensor(); ~MiopenTensor(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MiopenTensor); Status Set(const std::vector& input_dims, miopenDataType_t dataType); Status Set(const MiopenTensor& x_desc, miopenBatchNormMode_t mode); diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 0c981f35e7..38f855eaee 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -162,6 +162,8 @@ Status ReduceKernel::ReduceKernelShared( miopenReduceTensorOp_t miopen_reduce_op, std::vector& /*output_dims*/) const { typedef typename ToHipType::MappedType HipT; + //typedef typename ToHipType::MappedType HipOutT; + //miopenDataType_t miopen_type_X = MiopenTensor::GetDataType(); const auto rank = input_shape.NumDimensions(); // Block of fast matrix row reduction. diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.h b/onnxruntime/core/providers/rocm/reduction/reduction_ops.h index 936735aab7..f402851db1 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.h @@ -10,21 +10,6 @@ namespace onnxruntime { namespace rocm { -enum miopenReduceTensorOp_t { - MIOPEN_REDUCE_TENSOR_ADD, - MIOPEN_REDUCE_TENSOR_MUL, - MIOPEN_REDUCE_TENSOR_MIN, - MIOPEN_REDUCE_TENSOR_MAX, - MIOPEN_REDUCE_TENSOR_AVG, - MIOPEN_REDUCE_TENSOR_NORM1, - MIOPEN_REDUCE_TENSOR_NORM2, -}; - -enum miopenReduceTensorIndices_t { - MIOPEN_REDUCE_TENSOR_NO_INDICES, - MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES, -}; - namespace ReductionOps { // Implementation that holds the core logic of reduction op processing @@ -133,7 +118,8 @@ class ReduceL1 final : public ReduceKernel { ReduceL1(const OpKernelInfo& info) : ReduceKernel(info) {} Status ComputeInternal(OpKernelContext* ctx) const override { - return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_NORM1); + //return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_NORM1); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen does not yet support reduce norm1."); } }; @@ -143,7 +129,8 @@ class ReduceL2 final : public ReduceKernel { ReduceL2(const OpKernelInfo& info) : ReduceKernel(info) {} Status ComputeInternal(OpKernelContext* ctx) const override { - return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_NORM2); + //return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_NORM2); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen does not yet support reduce norm2."); } }; @@ -163,7 +150,8 @@ class ReduceMean final : public ReduceKernel { ReduceMean(const OpKernelInfo& info) : ReduceKernel(info) {} Status ComputeInternal(OpKernelContext* ctx) const override { - return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_AVG); + //return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_AVG); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen does not yet support reduce avg."); } }; @@ -248,5 +236,44 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, const TensorShape* input_shape_override = nullptr); +// ROCM's reduction descriptor miopenReduceTensorDescriptor_t is a pointer so +// it's safer to wrap it with automatically memory deleter as MiopenReduceDescriptor. +// An implicit caster from MiopenReduceDescriptor to miopenReduceTensorDescriptor_t +// is implemented below, so ROCM can seamlessly work. +class MiopenReduceDescriptor final { + public: + MiopenReduceDescriptor() : desc_(nullptr) { + } + + ~MiopenReduceDescriptor() { + if (desc_ != nullptr) { + miopenDestroyReduceTensorDescriptor(desc_); + desc_ = nullptr; + } + } + + MiopenReduceDescriptor(const MiopenReduceDescriptor&) = delete; + MiopenReduceDescriptor& operator=(const MiopenReduceDescriptor&) = delete; + + Status Set(miopenReduceTensorOp_t op, miopenDataType_t type, miopenReduceTensorIndices_t indices) { + if (!desc_) + MIOPEN_RETURN_IF_ERROR(miopenCreateReduceTensorDescriptor(&desc_)); + + MIOPEN_RETURN_IF_ERROR(miopenSetReduceTensorDescriptor( + desc_, + op, + type, + MIOPEN_PROPAGATE_NAN, + indices, + MIOPEN_32BIT_INDICES)); // currently only the 32-bit (unsigned int) type is supported. + return Status::OK(); + } + + operator miopenReduceTensorDescriptor_t() const { return desc_; } + + private: + miopenReduceTensorDescriptor_t desc_; +}; + } // namespace rocm } // namespace onnxruntime \ No newline at end of file