From 4fa88a0a23775dcc8bd70e16d201c6efc5f5a034 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 14 Apr 2020 14:30:26 -0700 Subject: [PATCH] Remove cast to OpKernelContextInternal to get threadpool and directly use OpKernelContext. (#3523) --- orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc | 4 +--- orttraining/orttraining/training_ops/cpu/op_gradients.cc | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc index cc1e7704bb..27a6729716 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/conv_grad.cc @@ -18,15 +18,13 @@ #include "orttraining/training_ops/cpu/nn/conv_grad.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { namespace contrib { template <typename T> Status ConvGrad<T>::Compute(OpKernelContext* context) const { - auto ctx_internal = static_cast<OpKernelContextInternal*>(context); - concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); const Tensor* dY = context->Input<Tensor>(0); const Tensor* X = context->Input<Tensor>(1); diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index 47e9032905..0ffcaf9ca4 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -10,7 +10,6 @@ #include "core/providers/cpu/math/element_wise_ops.h" #include "core/providers/cpu/math/matmul_helper.h" #include "gsl/gsl" -#include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { namespace contrib { @@ -91,8 +90,7 @@ Status SoftmaxGrad::Compute(OpKernelContext* context) const { scaledata + i, nullptr); } - auto ctx_internal = 
static_cast<OpKernelContextInternal*>(context); - concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, scaledata, sum_multiplier_.data(), 1, dXdata, tp);