Fix a couple of issues mentioned in the PR comments. (#8936)

This commit is contained in:
Scott McKay 2021-09-02 17:58:29 +10:00 committed by GitHub
parent ddbc8bc5fc
commit b058dee648
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 12 deletions

View file

@ -227,15 +227,15 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
// call StridedCopy if there is a type with the same size as T in the set of EnabledTypes
// e.g. if uint32_t is enabled all 4 byte types are supported
template <typename EnabledTypes, typename T>
bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
Tensor& dst,
std::ptrdiff_t dst_offset,
const std::vector<int64_t>& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
const std::vector<int64_t>& src_strides) {
inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
Tensor& dst,
std::ptrdiff_t dst_offset,
const std::vector<int64_t>& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
const std::vector<int64_t>& src_strides) {
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
if (enabled) {
if constexpr (enabled) {
// T doesn't necessarily match the data type in src or dst so use reinterpret_cast.
// it will be a type with the same size though, which is all that matters given we're only copying bits.
StridedCopy<T>(thread_pool,
@ -250,15 +250,15 @@ bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
// EnabledTypes is an onnxruntime::TypeList with the enabled types in this build.
// see "core/framework/element_type_lists.h" for default lists or the usage in
// onnxruntime/core/providers/cpu/tensor/concat.cc for
// onnxruntime/core/providers/cpu/tensor/concat.cc
template <typename EnabledDataTypes>
Status DispatchStridedCopy(concurrency::ThreadPool* thread_pool,
Tensor& dst,
std::ptrdiff_t dst_offset,
const std::vector<int64_t> dst_strides,
const std::vector<int64_t>& dst_strides,
const TensorShape& copy_shape,
const Tensor& src,
const std::vector<int64_t> src_strides) {
const std::vector<int64_t>& src_strides) {
ORT_ENFORCE(dst.DataType() == src.DataType(), "src and dst types must match");
bool supported = false;

View file

@ -255,7 +255,7 @@ static inline bool CallSliceImplIfEnabled(OpKernelContext* ctx,
SliceOp::PrepareForComputeMetadata& compute_metadata,
Status& status) {
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
if (enabled) {
if constexpr (enabled) {
status = SliceImpl<T>(ctx, input_tensor, compute_metadata);
}