diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index ec8799ed67..9402232a24 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -152,7 +152,7 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc); // Perform the transpose - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(cuda_ep_->GetDeviceProp(), + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), Stream(ctx), GetCublasHandle(ctx), permutation, *X, *temp_input)); @@ -192,7 +192,7 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { if (is_transpose_required) { // Perform the transpose to get the axes back to the original ordering - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(cuda_ep_->GetDeviceProp(), + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), Stream(ctx), GetCublasHandle(ctx), permutation, *intermediate_output, *Y)); diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h index 6c49ebd935..bbe63e66e6 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.h +++ b/onnxruntime/core/providers/cuda/math/softmax.h @@ -46,11 +46,6 @@ class Softmax final : public CudaKernel { } log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax"; - - // We need to cast away the const as PerThreadCublasHandle() is currently a non-const method - // TODO: Clean up the CUDAExecutionProvider interface to avoid this - cuda_ep_ = const_cast( - static_cast(info.GetExecutionProvider())); } Status ComputeInternal(OpKernelContext* context) const override; @@ -59,10 +54,6 @@ class Softmax final : public CudaKernel { int64_t axis_; bool log_softmax_; int opset_; - - // We need to access to the CUDA EP instance to get the cublas handle to use - // for transposing(if applicable) - CUDAExecutionProvider* cuda_ep_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index fe10bce6ea..5a07737d92 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -152,7 +152,7 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc); // Perform the transpose - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(), + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), Stream(ctx), GetRocblasHandle(ctx), permutation, *X, *temp_input)); @@ -194,7 +194,7 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { if (is_transpose_required) { // Perform the transpose to get the axes back to the original ordering - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(), + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(), Stream(ctx), GetRocblasHandle(ctx), permutation, *intermediate_output, *Y)); diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h index 95c276cef9..49bfddad36 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.h +++ b/onnxruntime/core/providers/rocm/math/softmax.h @@ -51,11 +51,6 @@ class Softmax final : public RocmKernel { } log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax"; - - // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method - // TODO: Clean up the ROCMExecutionProvider interface to avoid this - rocm_ep_ = const_cast( - static_cast(info.GetExecutionProvider())); } Status ComputeInternal(OpKernelContext* context) const override; @@ -64,10 +59,6 @@ class Softmax final : public RocmKernel { int64_t axis_; bool log_softmax_; int opset_; - - // We need to access to the ROCM EP instance to get the rocblas handle to use - // for transposing(if applicable) - ROCMExecutionProvider* rocm_ep_; }; } // namespace rocm