From 1a1dd4843d97423ef1d322d9e2eb5ec25322de3e Mon Sep 17 00:00:00 2001 From: Suffian Khan Date: Thu, 18 Mar 2021 10:09:45 -0700 Subject: [PATCH] Enable opset 13 for Rocm (#7047) * enable opset13 * import cuda changes for opset 13 softmax to rocm as well --- .../core/providers/rocm/math/softmax.cc | 107 +++++++- .../providers/rocm/rocm_execution_provider.cc | 258 +++++++++--------- 2 files changed, 226 insertions(+), 139 deletions(-) diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index 5d66c742e6..edeea72f59 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -6,6 +6,7 @@ #include "core/providers/common.h" #include "core/providers/rocm/miopen_common.h" #include "core/providers/rocm/shared_inc/accumulation_type.h" +#include "core/providers/rocm/tensor/transpose.h" namespace onnxruntime { namespace rocm { @@ -20,10 +21,8 @@ Status SoftMaxComputeHelper( int64_t axis) { typedef typename ToHipType::MappedType HipT; - const int64_t normalized_axis = HandleNegativeAxis(axis, input_shape.NumDimensions()); - - int64_t N = input_shape.SizeToDimension(normalized_axis); - int64_t D = input_shape.SizeFromDimension(normalized_axis); + int64_t N = input_shape.SizeToDimension(axis); + int64_t D = input_shape.SizeFromDimension(axis); auto Y_data = reinterpret_cast(Y); auto X_data = reinterpret_cast(X); @@ -113,17 +112,105 @@ template Status Softmax::ComputeInternal(OpKernelContext* ctx) const { const Tensor* X = ctx->Input(0); const TensorShape& input_shape{X->Shape()}; - const T* X_data = X->template Data(); - T* Y_data = ctx->Output(0, input_shape)->template MutableData(); + size_t rank = input_shape.NumDimensions(); + Tensor* Y = ctx->Output(0, input_shape); + // special case when there is a dim value of 0 in the shape. if (input_shape.Size() == 0) return Status::OK(); - if (log_softmax_) { - return SoftMaxComputeHelper(Stream(), X_data, input_shape, Y_data, MiopenHandle(), axis_); - } else { - return SoftMaxComputeHelper(Stream(), X_data, input_shape, Y_data, MiopenHandle(), axis_); + // handle negative and enforce axis is valid + const size_t axis = static_cast(HandleNegativeAxis(axis_, rank)); + + bool is_transpose_required = false; + Tensor transposed_input; + std::vector transposed_input_dims; + Tensor intermediate_output; // output that the softmax implementation will write into while using transposed input + std::vector permutation(rank); + + // The "semantic" meaning of axis has changed in opset-13. + // Please compare: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax + // with https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11 for detailed explanations + // To account for the opset-13 behavior, our plan will be to transpose the "axis" dim to the innermost dim + // and perform softmax and then reverse the transpose. We can skip the transposing aspect if the axis is already + // the innermost dim + if (opset_ >= 13 && axis != (rank - 1)) { + is_transpose_required = true; } + + if (is_transpose_required) { + AllocatorPtr alloc; + auto status = ctx->GetTempSpaceAllocator(&alloc); + if (!status.IsOK()) + return status; + + std::iota(std::begin(permutation), std::end(permutation), 0); + + // swap the innermost dim with the dim corresponding to axis + permutation[axis] = rank - 1; + permutation[rank - 1] = axis; + + transposed_input_dims.reserve(rank); + for (auto e : permutation) { + transposed_input_dims.push_back(input_shape[e]); + } + + // Allocate a temporary tensor to hold transposed input + Tensor temp_input(X->DataType(), TensorShape(transposed_input_dims), alloc); + + // Perform the transpose + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(), + Stream(), + RocblasHandle(), + permutation, *X, temp_input)); + transposed_input = std::move(temp_input); + + // Allocate memory for the intermediate output + Tensor temp_output(Y->DataType(), TensorShape(transposed_input_dims), alloc); + intermediate_output = std::move(temp_output); + } + + const T* X_data = nullptr; + T* Y_data = nullptr; + const TensorShape* compute_input_shape = nullptr; + + if (is_transpose_required) { // use intermediate buffers to compute the softmax values + X_data = transposed_input.template Data(); + Y_data = intermediate_output.template MutableData(); + compute_input_shape = &transposed_input.Shape(); + } else { // use the node input/output directly + X_data = X->template Data(); + Y_data = Y->template MutableData(); + compute_input_shape = &input_shape; + } + + Status status; + if (log_softmax_) { + status = SoftMaxComputeHelper(Stream(), X_data, *compute_input_shape, Y_data, MiopenHandle(), + is_transpose_required ? static_cast(rank) - 1 + : static_cast(axis)); + } else { + status = SoftMaxComputeHelper(Stream(), X_data, *compute_input_shape, Y_data, MiopenHandle(), + is_transpose_required ? static_cast(rank) - 1 + : static_cast(axis)); + } + + if (!status.IsOK()) + return status; + + if (is_transpose_required) { + std::vector reverse_permutation(rank); + for (size_t i = 0, end = permutation.size(); i < end; ++i) { + reverse_permutation[permutation[i]] = i; + } + // Perform the transpose to get the axes back to the original ordering + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(), + Stream(), + RocblasHandle(), + reverse_permutation, intermediate_output, *Y)); + } + + return Status::OK(); } #define SPECIALIZED_COMPUTE(T) \ diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9f020477eb..2e93fae664 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1503,28 +1503,28 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // OpSet 13 - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1532,71 +1532,71 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1612,46 +1612,46 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1664,10 +1664,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1686,7 +1686,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1694,12 +1694,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) {