[ROCm] Remove redundant ep field in softmax (#17048)

This commit is contained in:
cloudhan 2023-08-17 11:53:30 +08:00 committed by GitHub
parent 5249b7ab7c
commit 049adb9f31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 4 additions and 22 deletions

View file

@ -152,7 +152,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc);
// Perform the transpose
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(cuda_ep_->GetDeviceProp(),
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
Stream(ctx),
GetCublasHandle(ctx),
permutation, *X, *temp_input));
@ -192,7 +192,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
if (is_transpose_required) {
// Perform the transpose to get the axes back to the original ordering
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(cuda_ep_->GetDeviceProp(),
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
Stream(ctx),
GetCublasHandle(ctx),
permutation, *intermediate_output, *Y));

View file

@ -46,11 +46,6 @@ class Softmax final : public CudaKernel {
}
log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax";
// We need to cast away the const as PerThreadCublasHandle() is currently a non-const method
// TODO: Clean up the CUDAExecutionProvider interface to avoid this
cuda_ep_ = const_cast<CUDAExecutionProvider*>(
static_cast<const CUDAExecutionProvider*>(info.GetExecutionProvider()));
}
Status ComputeInternal(OpKernelContext* context) const override;
@ -59,10 +54,6 @@ class Softmax final : public CudaKernel {
int64_t axis_;
bool log_softmax_;
int opset_;
// We need to access to the CUDA EP instance to get the cublas handle to use
// for transposing(if applicable)
CUDAExecutionProvider* cuda_ep_;
};
} // namespace cuda

View file

@ -152,7 +152,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc);
// Perform the transpose
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(),
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
Stream(ctx),
GetRocblasHandle(ctx),
permutation, *X, *temp_input));
@ -194,7 +194,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
if (is_transpose_required) {
// Perform the transpose to get the axes back to the original ordering
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(),
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
Stream(ctx),
GetRocblasHandle(ctx),
permutation, *intermediate_output, *Y));

View file

@ -51,11 +51,6 @@ class Softmax final : public RocmKernel {
}
log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax";
// We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method
// TODO: Clean up the ROCMExecutionProvider interface to avoid this
rocm_ep_ = const_cast<ROCMExecutionProvider*>(
static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()));
}
Status ComputeInternal(OpKernelContext* context) const override;
@ -64,10 +59,6 @@ class Softmax final : public RocmKernel {
int64_t axis_;
bool log_softmax_;
int opset_;
// We need to access to the ROCM EP instance to get the rocblas handle to use
// for transposing(if applicable)
ROCMExecutionProvider* rocm_ep_;
};
} // namespace rocm