From b058dee64896dde347558937dc0057cd0791d8fb Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 2 Sep 2021 17:58:29 +1000 Subject: [PATCH] Fix a couple of issues mentioned in the PR comments. (#8936) --- onnxruntime/core/providers/cpu/tensor/copy.h | 22 +++++++++---------- .../core/providers/cpu/tensor/slice.cc | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/copy.h b/onnxruntime/core/providers/cpu/tensor/copy.h index 6eba9e3c31..ea7b48a27b 100644 --- a/onnxruntime/core/providers/cpu/tensor/copy.h +++ b/onnxruntime/core/providers/cpu/tensor/copy.h @@ -227,15 +227,15 @@ void StridedCopy(concurrency::ThreadPool* thread_pool, // call StridedCopy if there is a type with the same size as T in the set of EnabledTypes // e.g. if uint32_t is enabled all 4 byte types are supported template -bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool, - Tensor& dst, - std::ptrdiff_t dst_offset, - const std::vector& dst_strides, - const TensorShape& copy_shape, - const Tensor& src, - const std::vector& src_strides) { +inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool, + Tensor& dst, + std::ptrdiff_t dst_offset, + const std::vector& dst_strides, + const TensorShape& copy_shape, + const Tensor& src, + const std::vector& src_strides) { constexpr bool enabled = utils::HasTypeWithSameSize(); - if (enabled) { + if constexpr (enabled) { // T doesn't necessarily match the data type in src or dst so use reinterpret_cast. // it will be a type with the same size though, which is all that matters given we're only copying bits. StridedCopy(thread_pool, @@ -250,15 +250,15 @@ bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool, // EnabledTypes is an onnxruntime::TypeList with the enabled types in this build. // see "core/framework/element_type_lists.h" for default lists or the usage in -// onnxruntime/core/providers/cpu/tensor/concat.cc for +// onnxruntime/core/providers/cpu/tensor/concat.cc template Status DispatchStridedCopy(concurrency::ThreadPool* thread_pool, Tensor& dst, std::ptrdiff_t dst_offset, - const std::vector dst_strides, + const std::vector& dst_strides, const TensorShape& copy_shape, const Tensor& src, - const std::vector src_strides) { + const std::vector& src_strides) { ORT_ENFORCE(dst.DataType() == src.DataType(), "src and dst types must match"); bool supported = false; diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index 521e6ac6cd..1d91c5854e 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -255,7 +255,7 @@ static inline bool CallSliceImplIfEnabled(OpKernelContext* ctx, SliceOp::PrepareForComputeMetadata& compute_metadata, Status& status) { constexpr bool enabled = utils::HasTypeWithSameSize(); - if (enabled) { + if constexpr (enabled) { status = SliceImpl(ctx, input_tensor, compute_metadata); }