Remove cast to OpKernelContextInternal to get threadpool and directly use OpKernelContext. (#3523)

This commit is contained in:
edgchen1 2020-04-14 14:30:26 -07:00 committed by GitHub
parent 06b63975c0
commit 4fa88a0a23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 6 deletions

View file

@ -18,15 +18,13 @@
#include "orttraining/training_ops/cpu/nn/conv_grad.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/framework/op_kernel_context_internal.h"
namespace onnxruntime {
namespace contrib {
template <typename T>
Status ConvGrad<T>::Compute(OpKernelContext* context) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool();
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
const Tensor* dY = context->Input<Tensor>(0);
const Tensor* X = context->Input<Tensor>(1);

View file

@ -10,7 +10,6 @@
#include "core/providers/cpu/math/element_wise_ops.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "gsl/gsl"
#include "core/framework/op_kernel_context_internal.h"
namespace onnxruntime {
namespace contrib {
@ -91,8 +90,7 @@ Status SoftmaxGrad<T>::Compute(OpKernelContext* context) const {
scaledata + i, nullptr);
}
auto ctx_internal = static_cast<OpKernelContextInternal*>(context);
concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool();
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
math::Gemm<float>(CblasNoTrans, CblasNoTrans, n, d, 1, -1,
scaledata, sum_multiplier_.data(), 1,
dXdata, tp);