mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Fix a couple of issues mentioned in the PR comments. (#8936)
This commit is contained in:
parent
ddbc8bc5fc
commit
b058dee648
2 changed files with 12 additions and 12 deletions
|
|
@ -227,15 +227,15 @@ void StridedCopy(concurrency::ThreadPool* thread_pool,
|
|||
// call StridedCopy if there is a type with the same size as T in the set of EnabledTypes
|
||||
// e.g. if uint32_t is enabled all 4 byte types are supported
|
||||
template <typename EnabledTypes, typename T>
|
||||
bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
|
||||
Tensor& dst,
|
||||
std::ptrdiff_t dst_offset,
|
||||
const std::vector<int64_t>& dst_strides,
|
||||
const TensorShape& copy_shape,
|
||||
const Tensor& src,
|
||||
const std::vector<int64_t>& src_strides) {
|
||||
inline bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
|
||||
Tensor& dst,
|
||||
std::ptrdiff_t dst_offset,
|
||||
const std::vector<int64_t>& dst_strides,
|
||||
const TensorShape& copy_shape,
|
||||
const Tensor& src,
|
||||
const std::vector<int64_t>& src_strides) {
|
||||
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
|
||||
if (enabled) {
|
||||
if constexpr (enabled) {
|
||||
// T doesn't necessarily match the data type in src or dst so use reinterpret_cast.
|
||||
// it will be a type with the same size though, which is all that matters given we're only copying bits.
|
||||
StridedCopy<T>(thread_pool,
|
||||
|
|
@ -250,15 +250,15 @@ bool StridedCopyIfEnabled(concurrency::ThreadPool* thread_pool,
|
|||
|
||||
// EnabledTypes is an onnxruntime::TypeList with the enabled types in this build.
|
||||
// see "core/framework/element_type_lists.h" for default lists or the usage in
|
||||
// onnxruntime/core/providers/cpu/tensor/concat.cc for
|
||||
// onnxruntime/core/providers/cpu/tensor/concat.cc
|
||||
template <typename EnabledDataTypes>
|
||||
Status DispatchStridedCopy(concurrency::ThreadPool* thread_pool,
|
||||
Tensor& dst,
|
||||
std::ptrdiff_t dst_offset,
|
||||
const std::vector<int64_t> dst_strides,
|
||||
const std::vector<int64_t>& dst_strides,
|
||||
const TensorShape& copy_shape,
|
||||
const Tensor& src,
|
||||
const std::vector<int64_t> src_strides) {
|
||||
const std::vector<int64_t>& src_strides) {
|
||||
ORT_ENFORCE(dst.DataType() == src.DataType(), "src and dst types must match");
|
||||
|
||||
bool supported = false;
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ static inline bool CallSliceImplIfEnabled(OpKernelContext* ctx,
|
|||
SliceOp::PrepareForComputeMetadata& compute_metadata,
|
||||
Status& status) {
|
||||
constexpr bool enabled = utils::HasTypeWithSameSize<EnabledTypes, T>();
|
||||
if (enabled) {
|
||||
if constexpr (enabled) {
|
||||
status = SliceImpl<T>(ctx, input_tensor, compute_metadata);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue