mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
[ROCm] Remove redundant ep field in softmax (#17048)
This commit is contained in:
parent
5249b7ab7c
commit
049adb9f31
4 changed files with 4 additions and 22 deletions
|
|
@ -152,7 +152,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
|
||||
// Perform the transpose
|
||||
- ORT_RETURN_IF_ERROR(Transpose::DoTranspose(cuda_ep_->GetDeviceProp(),
|
||||
+ ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
|
||||
Stream(ctx),
|
||||
GetCublasHandle(ctx),
|
||||
permutation, *X, *temp_input));
|
||||
|
|
@ -192,7 +192,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
|
||||
if (is_transpose_required) {
|
||||
// Perform the transpose to get the axes back to the original ordering
|
||||
- ORT_RETURN_IF_ERROR(Transpose::DoTranspose(cuda_ep_->GetDeviceProp(),
|
||||
+ ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
|
||||
Stream(ctx),
|
||||
GetCublasHandle(ctx),
|
||||
permutation, *intermediate_output, *Y));
|
||||
|
|
|
|||
|
|
@ -46,11 +46,6 @@ class Softmax final : public CudaKernel {
|
|||
}
|
||||
|
||||
log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax";
|
||||
|
||||
- // We need to cast away the const as PerThreadCublasHandle() is currently a non-const method
|
||||
- // TODO: Clean up the CUDAExecutionProvider interface to avoid this
|
||||
- cuda_ep_ = const_cast<CUDAExecutionProvider*>(
|
||||
- static_cast<const CUDAExecutionProvider*>(info.GetExecutionProvider()));
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -59,10 +54,6 @@ class Softmax final : public CudaKernel {
|
|||
int64_t axis_;
|
||||
bool log_softmax_;
|
||||
int opset_;
|
||||
|
||||
- // We need to access to the CUDA EP instance to get the cublas handle to use
|
||||
- // for transposing(if applicable)
|
||||
- CUDAExecutionProvider* cuda_ep_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
auto temp_input = Tensor::Create(X->DataType(), TensorShape(transposed_input_dims), alloc);
|
||||
|
||||
// Perform the transpose
|
||||
- ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(),
|
||||
+ ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
|
||||
Stream(ctx),
|
||||
GetRocblasHandle(ctx),
|
||||
permutation, *X, *temp_input));
|
||||
|
|
@ -194,7 +194,7 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
|
||||
if (is_transpose_required) {
|
||||
// Perform the transpose to get the axes back to the original ordering
|
||||
- ORT_RETURN_IF_ERROR(Transpose::DoTranspose(rocm_ep_->GetDeviceProp(),
|
||||
+ ORT_RETURN_IF_ERROR(Transpose::DoTranspose(GetDeviceProp(),
|
||||
Stream(ctx),
|
||||
GetRocblasHandle(ctx),
|
||||
permutation, *intermediate_output, *Y));
|
||||
|
|
|
|||
|
|
@ -51,11 +51,6 @@ class Softmax final : public RocmKernel {
|
|||
}
|
||||
|
||||
log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax";
|
||||
|
||||
- // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method
|
||||
- // TODO: Clean up the ROCMExecutionProvider interface to avoid this
|
||||
- rocm_ep_ = const_cast<ROCMExecutionProvider*>(
|
||||
- static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()));
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -64,10 +59,6 @@ class Softmax final : public RocmKernel {
|
|||
int64_t axis_;
|
||||
bool log_softmax_;
|
||||
int opset_;
|
||||
|
||||
- // We need to access to the ROCM EP instance to get the rocblas handle to use
|
||||
- // for transposing(if applicable)
|
||||
- ROCMExecutionProvider* rocm_ep_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
Loading…
Reference in a new issue