mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Remove cast to OpKernelContextInternal to get threadpool and directly use OpKernelContext. (#3523)
This commit is contained in:
parent
06b63975c0
commit
4fa88a0a23
2 changed files with 2 additions and 6 deletions
|
|
@ -18,15 +18,13 @@
|
|||
#include "orttraining/training_ops/cpu/nn/conv_grad.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::Compute(OpKernelContext* context) const {
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool();
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
|
||||
const Tensor* dY = context->Input<Tensor>(0);
|
||||
const Tensor* X = context->Input<Tensor>(1);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@
|
|||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "gsl/gsl"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -91,8 +90,7 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
|
|||
scaledata + i, nullptr);
|
||||
}
|
||||
|
||||
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
|
||||
concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool();
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
math::Gemm<float>(CblasNoTrans, CblasNoTrans, n, d, 1, -1,
|
||||
scaledata, sum_multiplier_.data(), 1,
|
||||
dXdata, tp);
|
||||
|
|
|
|||
Loading…
Reference in a new issue