mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
miopen common
This commit is contained in:
parent
554184bcc4
commit
4c1db50df5
4 changed files with 71 additions and 21 deletions
|
|
@ -41,16 +41,36 @@ Status MiopenTensor::Set(const std::vector<int64_t>& input_dims, miopenDataType_
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MiopenTensor::Set(const MiopenTensor& x_desc, miopenBatchNormMode_t mode) {
|
||||
ORT_RETURN_IF_ERROR(CreateTensorIfNeeded());
|
||||
MIOPEN_RETURN_IF_ERROR(miopenDeriveBNTensorDescriptor(tensor_, x_desc, mode));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
miopenDataType_t MiopenTensor::GetDataType() {
|
||||
ORT_THROW("miopen engine currently supports only single/half precision data types.");
|
||||
ORT_THROW("miopen engine currently supports only single/half/int32/int8 precision data types.");
|
||||
}
|
||||
|
||||
template<>
|
||||
miopenDataType_t MiopenTensor::GetDataType<float>() {
|
||||
return miopenFloat;
|
||||
}
|
||||
|
||||
template <>
|
||||
miopenDataType_t MiopenTensor::GetDataType<float>() { return miopenFloat; }
|
||||
miopenDataType_t MiopenTensor::GetDataType<half>() {
|
||||
return miopenHalf;
|
||||
}
|
||||
|
||||
template <>
|
||||
miopenDataType_t MiopenTensor::GetDataType<half>() { return miopenHalf; }
|
||||
miopenDataType_t MiopenTensor::GetDataType<int32_t>() {
|
||||
return miopenInt32;
|
||||
}
|
||||
|
||||
template <>
|
||||
miopenDataType_t MiopenTensor::GetDataType<int8_t>() {
|
||||
return miopenInt8;
|
||||
}
|
||||
|
||||
template <>
|
||||
const float Consts<float>::One = 1;
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class MiopenTensor final {
|
|||
public:
|
||||
MiopenTensor();
|
||||
~MiopenTensor();
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MiopenTensor);
|
||||
|
||||
Status Set(const std::vector<int64_t>& input_dims, miopenDataType_t dataType);
|
||||
Status Set(const MiopenTensor& x_desc, miopenBatchNormMode_t mode);
|
||||
|
|
|
|||
|
|
@ -162,6 +162,8 @@ Status ReduceKernel<allow_multi_axes>::ReduceKernelShared(
|
|||
miopenReduceTensorOp_t miopen_reduce_op,
|
||||
std::vector<int64_t>& /*output_dims*/) const {
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
//typedef typename ToHipType<OutT>::MappedType HipOutT;
|
||||
//miopenDataType_t miopen_type_X = MiopenTensor::GetDataType<HipT>();
|
||||
const auto rank = input_shape.NumDimensions();
|
||||
|
||||
// Block of fast matrix row reduction.
|
||||
|
|
|
|||
|
|
@ -10,21 +10,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
enum miopenReduceTensorOp_t {
|
||||
MIOPEN_REDUCE_TENSOR_ADD,
|
||||
MIOPEN_REDUCE_TENSOR_MUL,
|
||||
MIOPEN_REDUCE_TENSOR_MIN,
|
||||
MIOPEN_REDUCE_TENSOR_MAX,
|
||||
MIOPEN_REDUCE_TENSOR_AVG,
|
||||
MIOPEN_REDUCE_TENSOR_NORM1,
|
||||
MIOPEN_REDUCE_TENSOR_NORM2,
|
||||
};
|
||||
|
||||
enum miopenReduceTensorIndices_t {
|
||||
MIOPEN_REDUCE_TENSOR_NO_INDICES,
|
||||
MIOPEN_REDUCE_TENSOR_FLATTENED_INDICES,
|
||||
};
|
||||
|
||||
namespace ReductionOps {
|
||||
|
||||
// Implementation that holds the core logic of reduction op processing
|
||||
|
|
@ -133,7 +118,8 @@ class ReduceL1 final : public ReduceKernel<true> {
|
|||
ReduceL1(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_NORM1);
|
||||
//return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_NORM1);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen does not yet support reduce norm1.");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -143,7 +129,8 @@ class ReduceL2 final : public ReduceKernel<true> {
|
|||
ReduceL2(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_NORM2);
|
||||
//return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_NORM2);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen does not yet support reduce norm2.");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -163,7 +150,8 @@ class ReduceMean final : public ReduceKernel<true> {
|
|||
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_AVG);
|
||||
//return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_AVG);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "MIOpen does not yet support reduce avg.");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -248,5 +236,44 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
|
|||
bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
|
||||
const TensorShape* input_shape_override = nullptr);
|
||||
|
||||
// ROCM's reduction descriptor miopenReduceTensorDescriptor_t is a pointer so
|
||||
// it's safer to wrap it with automatically memory deleter as MiopenReduceDescriptor.
|
||||
// An implicit caster from MiopenReduceDescriptor to miopenReduceTensorDescriptor_t
|
||||
// is implemented below, so ROCM can seamlessly work.
|
||||
class MiopenReduceDescriptor final {
|
||||
public:
|
||||
MiopenReduceDescriptor() : desc_(nullptr) {
|
||||
}
|
||||
|
||||
~MiopenReduceDescriptor() {
|
||||
if (desc_ != nullptr) {
|
||||
miopenDestroyReduceTensorDescriptor(desc_);
|
||||
desc_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
MiopenReduceDescriptor(const MiopenReduceDescriptor&) = delete;
|
||||
MiopenReduceDescriptor& operator=(const MiopenReduceDescriptor&) = delete;
|
||||
|
||||
Status Set(miopenReduceTensorOp_t op, miopenDataType_t type, miopenReduceTensorIndices_t indices) {
|
||||
if (!desc_)
|
||||
MIOPEN_RETURN_IF_ERROR(miopenCreateReduceTensorDescriptor(&desc_));
|
||||
|
||||
MIOPEN_RETURN_IF_ERROR(miopenSetReduceTensorDescriptor(
|
||||
desc_,
|
||||
op,
|
||||
type,
|
||||
MIOPEN_PROPAGATE_NAN,
|
||||
indices,
|
||||
MIOPEN_32BIT_INDICES)); // currently only the 32-bit (unsigned int) type is supported.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
operator miopenReduceTensorDescriptor_t() const { return desc_; }
|
||||
|
||||
private:
|
||||
miopenReduceTensorDescriptor_t desc_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue