From d10d66cc84be67c2d7b4b2bd323e768560c4a390 Mon Sep 17 00:00:00 2001
From: Jian Chen
Date: Tue, 8 Nov 2022 23:58:05 -0500
Subject: [PATCH] Cjian/c4244 round 1a (#13483)

### Description
Redo this round of C4244 warning fixes using gsl::narrow and SafeInt.

### Motivation and Context
---
 onnxruntime/contrib_ops/cpu/activations.h     |   4 +-
 onnxruntime/contrib_ops/cpu/bert/attention.cc |  18 +-
 onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc |   8 +-
 onnxruntime/contrib_ops/cpu/cdist.cc          |  23 +-
 .../core/providers/cpu/tensor/upsample.cc     | 251 +++++++++---------
 5 files changed, 153 insertions(+), 151 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h
index 4a4f76c2ef..810ea5066a 100644
--- a/onnxruntime/contrib_ops/cpu/activations.h
+++ b/onnxruntime/contrib_ops/cpu/activations.h
@@ -10,7 +10,7 @@
 #include "core/platform/threadpool.h"
 #include <cmath>
 #include "core/providers/cpu/element_wise_ranged_transform.h"
-
+using onnxruntime::narrow;
 namespace onnxruntime {
 namespace functors {
@@ -82,7 +82,7 @@ class Gelu : public OpKernel {
       p_output[i] = value * static_cast<float>(M_SQRT1_2);
     }
 
-    MlasComputeErf(p_output, p_output, gsl::narrow_cast<size_t>(count));
+    MlasComputeErf(p_output, p_output, narrow<size_t>(count));
 
     for (int64_t i = 0; i < count; i++) {
       p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 85dec5b4c7..d60d7f64ee 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -11,7 +11,7 @@
 #include "core/platform/threadpool.h"
 
 using onnxruntime::concurrency::ThreadPool;
-
+using onnxruntime::narrow;
 namespace onnxruntime {
 namespace contrib {
@@ -75,7 +75,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
     return false;
  }
 
-  size_t loop_len = gsl::narrow_cast<size_t>(num_heads_);
+  size_t loop_len = narrow<size_t>(num_heads_);
   size_t packed_weights_data_size = packb_size * loop_len;  // The same size would be computed by AllocArray() below
   auto* packed_weights_data = static_cast<uint8_t*>(alloc->AllocArray(packb_size, loop_len));
@@ -124,13 +124,13 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
   }
 
   const auto* weights_data = weights.Data<T>();
-  const size_t input_hidden_size = gsl::narrow_cast<size_t>(weights_dims[0]);
+  const size_t input_hidden_size = narrow<size_t>(weights_dims[0]);
   size_t q_hidden_size, k_hidden_size, v_hidden_size;
 
   if (qkv_hidden_sizes_.size() != 0) {
-    q_hidden_size = gsl::narrow_cast<size_t>(qkv_hidden_sizes_[0]);
-    k_hidden_size = gsl::narrow_cast<size_t>(qkv_hidden_sizes_[1]);
-    v_hidden_size = gsl::narrow_cast<size_t>(qkv_hidden_sizes_[2]);
+    q_hidden_size = narrow<size_t>(qkv_hidden_sizes_[0]);
+    k_hidden_size = narrow<size_t>(qkv_hidden_sizes_[1]);
+    v_hidden_size = narrow<size_t>(qkv_hidden_sizes_[2]);
 
     if (q_hidden_size == 0 || k_hidden_size == 0 || v_hidden_size == 0) {
       return Status::OK();
@@ -140,7 +140,7 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
       return Status::OK();
     }
   } else {
-    const size_t hidden_size_x3 = gsl::narrow_cast<size_t>(weights_dims[1]);
+    const size_t hidden_size_x3 = narrow<size_t>(weights_dims[1]);
     const size_t hidden_size = hidden_size_x3 / 3;
 
     if (hidden_size % num_heads_ != 0) {
@@ -240,8 +240,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
   BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
 
   auto Q = reinterpret_cast<T*>(gemm_data);
-  auto K = Q + gsl::narrow_cast<size_t>(batch_size) * sequence_length * parameters.hidden_size;
-  auto V = K + gsl::narrow_cast<size_t>(batch_size) * sequence_length * parameters.hidden_size;
+  auto K = Q + narrow<size_t>(batch_size) * sequence_length * parameters.hidden_size;
+  auto V = K + narrow<size_t>(batch_size) * sequence_length * parameters.hidden_size;
 
   T* QKV[3] = {Q, K, V};
   const int qkv_head_size[3] = {parameters.head_size, parameters.head_size, parameters.v_head_size};
diff --git a/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc b/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc
index b6050534c8..006298e648 100644
--- a/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc
@@ -11,7 +11,7 @@
 #include "core/providers/common.h"
 #include "core/util/math_cpuonly.h"
 #include "core/mlas/inc/mlas.h"
-
+using onnxruntime::narrow;
 namespace onnxruntime {
 namespace contrib {
@@ -60,7 +60,7 @@ Status BiasGelu<T, use_approximation>::Compute(OpKernelContext* context) const {
       p_output[i] = value * (static_cast<float>(C) * value * value + static_cast<float>(B));
     }
 
-    MlasComputeTanh(p_output, p_output, gsl::narrow_cast<size_t>(count));
+    MlasComputeTanh(p_output, p_output, narrow<size_t>(count));
 
     for (int64_t i = 0; i < count; i++) {
       p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
@@ -106,7 +106,7 @@ void BiasGelu<T, use_approximation>::AddBiasGelu(
       temp[i] = value * 0.5f;
     }
 
-    MlasComputeTanh(output, output, gsl::narrow_cast<size_t>(count));
+    MlasComputeTanh(output, output, narrow<size_t>(count));
 
     for (int64_t i = 0; i < count; i++) {
       output[i] = temp[i] * (output[i] + 1.0f);
@@ -118,7 +118,7 @@ void BiasGelu<T, use_approximation>::AddBiasGelu(
       temp[i] = value * 0.5f;
     }
 
-    MlasComputeErf(output, output, gsl::narrow_cast<size_t>(count));
+    MlasComputeErf(output, output, narrow<size_t>(count));
 
     for (int64_t i = 0; i < count; i++) {
       output[i] = temp[i] * (output[i] + 1.0f);
diff --git a/onnxruntime/contrib_ops/cpu/cdist.cc b/onnxruntime/contrib_ops/cpu/cdist.cc
index 5365537b0d..d0ed81a9a6 100644
--- a/onnxruntime/contrib_ops/cpu/cdist.cc
+++ b/onnxruntime/contrib_ops/cpu/cdist.cc
@@ -3,11 +3,12 @@
 #include "cdist.h"
 #include "core/common/common.h"
+#include "core/common/safeint.h"
 #include "core/framework/op_kernel.h"
 #include "core/util/math.h"
 #include "core/util/math_cpuonly.h"
 #include "core/mlas/inc/mlas.h"
-
+using onnxruntime::narrow;
 namespace onnxruntime {
 namespace contrib {
 #define DEFINE_KERNEL(data_type) \
@@ -35,19 +36,19 @@ static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool) {
   // ReduceSumSquare for A
   std::vector<T> a_ss;
-  a_ss.resize(gsl::narrow_cast<size_t>(m));
+  a_ss.resize(narrow<size_t>(m));
   const auto* cur_a = a_data;
   for (int64_t i = 0; i < m; ++i) {
-    a_ss[gsl::narrow_cast<size_t>(i)] = ConstEigenVectorMap<T>(cur_a, gsl::narrow_cast<size_t>(k)).squaredNorm();
+    a_ss[narrow<size_t>(i)] = ConstEigenVectorMap<T>(cur_a, narrow<size_t>(k)).squaredNorm();
     cur_a += k;
   }
 
   // ReduceSumSquare for B
   std::vector<T> b_ss;
-  b_ss.resize(gsl::narrow_cast<size_t>(n));
+  b_ss.resize(narrow<size_t>(n));
   const auto* cur_b = b_data;
   for (int64_t i = 0; i < n; ++i) {
-    b_ss[gsl::narrow_cast<size_t>(i)] = ConstEigenVectorMap<T>(cur_b, gsl::narrow_cast<size_t>(k)).squaredNorm();
+    b_ss[narrow<size_t>(i)] = ConstEigenVectorMap<T>(cur_b, narrow<size_t>(k)).squaredNorm();
     cur_b += k;
   }
@@ -71,19 +72,19 @@ static void CalculateSqeuclidean(const Tensor& a, const Tensor& b, Tensor& c, concurrency::ThreadPool* threadpool) {
   ORT_UNUSED_PARAMETER(threadpool);
 
   // https://eigen.tuxfamily.org/dox/TopicWritingEfficientProductExpression.html
-  auto out_map = EigenMatrixMapRowMajor<T>(c_data, gsl::narrow_cast<size_t>(m), gsl::narrow_cast<size_t>(n));
+  auto out_map = EigenMatrixMapRowMajor<T>(c_data, SafeInt<size_t>(m), SafeInt<size_t>(n));
   out_map.noalias() = static_cast<T>(-2.) *
-                      (ConstEigenMatrixMapRowMajor<T>(a_data, gsl::narrow_cast<size_t>(m), gsl::narrow_cast<size_t>(k)) *
-                       ConstEigenMatrixMapRowMajor<T>(b_data, gsl::narrow_cast<size_t>(n), gsl::narrow_cast<size_t>(k)).transpose());
+                      (ConstEigenMatrixMapRowMajor<T>(a_data, SafeInt<size_t>(m), SafeInt<size_t>(k)) *
+                       ConstEigenMatrixMapRowMajor<T>(b_data, SafeInt<size_t>(n), SafeInt<size_t>(k)).transpose());
 #endif
 
   // add a_ss and b_ss, with broadcast
   // output shape is {m, n}
   auto* cur_out = c_data;
   for (int64_t i = 0; i < m; ++i) {
-    T a_val = a_ss[gsl::narrow_cast<size_t>(i)];
+    T a_val = a_ss[narrow<size_t>(i)];
     for (int64_t j = 0; j < n; ++j) {
-      *cur_out = (*cur_out + a_val) + b_ss[gsl::narrow_cast<size_t>(j)];
+      *cur_out = (*cur_out + a_val) + b_ss[narrow<size_t>(j)];
       ++cur_out;
     }
   }
@@ -114,7 +115,7 @@ common::Status CDist<T>::Compute(OpKernelContext* context) const {
   T* output = C->MutableData<T>();
 
   CalculateSqeuclidean<T>(*A, *B, *C, tp);
-  auto map_out = EigenVectorArrayMap<T>(output, gsl::narrow_cast<size_t>(output_shape.Size()));
+  auto map_out = EigenVectorArrayMap<T>(output, narrow<size_t>(output_shape.Size()));
 
   // because we use GEMM in CalculateSqeuclidean there's a slight chance a number extremely close to zero
   // could be negative, so we need to run abs() to avoid NaN's in the results.
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc
index 9173efe04f..1d4cf65c60 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample.cc
+++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc
@@ -7,6 +7,7 @@
 
 using namespace onnxruntime::common;
 using namespace std;
+using onnxruntime::narrow;
 namespace onnxruntime {
 
 #define REGISTER_VERSIONED_TYPED_KERNEL(T, start, end) \
@@ -82,7 +83,7 @@ static std::vector<int64_t> UpsampleNearestSetupRank1InputMapping(
       if (input_dim0_idx < 0) input_dim0_idx = 0;
     }
 
-    input_mapping[gsl::narrow_cast<size_t>(output_dim0_idx)] = input_dim0_idx;
+    input_mapping[narrow<size_t>(output_dim0_idx)] = input_dim0_idx;
   }
 
   return input_mapping;
@@ -98,37 +99,37 @@ UpsampleNearestSetupInputMappings(int64_t n_dim,
                                   bool extrapolation_enabled,
                                   const GetOriginalCoordinateFunc& get_original_coordinate,
                                   const GetNearestPixelFunc& get_nearest_pixel) {
-  std::vector<std::vector<int64_t>> input_mappings(gsl::narrow_cast<size_t>(n_dim));
+  std::vector<std::vector<int64_t>> input_mappings(narrow<size_t>(n_dim));
 
   for (int64_t axis = 0; axis < n_dim; ++axis) {
-    std::vector<int64_t>& input_mapping = input_mappings[gsl::narrow_cast<size_t>(axis)];
-    input_mapping.resize(gsl::narrow_cast<size_t>(output_shape[gsl::narrow_cast<size_t>(axis)]));
+    std::vector<int64_t>& input_mapping = input_mappings[narrow<size_t>(axis)];
+    input_mapping.resize(narrow<size_t>(output_shape[narrow<size_t>(axis)]));
 
     // When scale is 1.0, there is a one-to-one mapping between the dimension
     // in the input and the output and there is no need to apply the co-ordinate
     // transformation which should only be done when there is "resizing" required
-    if (scales[gsl::narrow_cast<size_t>(axis)] == 1.0f) {
-      for (int64_t dim = 0; dim < output_shape[gsl::narrow_cast<size_t>(axis)]; dim++) {
-        input_mapping[gsl::narrow_cast<size_t>(dim)] = dim * input_dim_factor[gsl::narrow_cast<size_t>(axis)];
+    if (scales[narrow<size_t>(axis)] == 1.0f) {
+      for (int64_t dim = 0; dim < output_shape[narrow<size_t>(axis)]; dim++) {
+        input_mapping[narrow<size_t>(dim)] = dim * input_dim_factor[narrow<size_t>(axis)];
       }
       continue;
     }
 
     // scale != 1.0
     const int64_t input_size = input_dim_factor[0] * input_shape[0];
-    for (int64_t dim = 0; dim < output_shape[gsl::narrow_cast<size_t>(axis)]; dim++) {
+    for (int64_t dim = 0; dim < output_shape[narrow<size_t>(axis)]; dim++) {
       float original_dim = get_original_coordinate(static_cast<float>(dim),
-                                                   scales[gsl::narrow_cast<size_t>(axis)],
-                                                   static_cast<float>(output_shape[gsl::narrow_cast<size_t>(axis)]),
-                                                   static_cast<float>(input_shape[gsl::narrow_cast<size_t>(axis)]),
-                                                   roi[gsl::narrow_cast<size_t>(axis)], roi[gsl::narrow_cast<size_t>(n_dim + axis)]);
+                                                   scales[narrow<size_t>(axis)],
+                                                   static_cast<float>(output_shape[narrow<size_t>(axis)]),
+                                                   static_cast<float>(input_shape[narrow<size_t>(axis)]),
+                                                   roi[narrow<size_t>(axis)], roi[SafeInt<size_t>(n_dim) + axis]);
 
-      bool need_extrapolation = (extrapolation_enabled && (original_dim < 0 || original_dim > input_shape[gsl::narrow_cast<size_t>(axis)] - 1));
-      int64_t input_dim = get_nearest_pixel(original_dim, scales[gsl::narrow_cast<size_t>(axis)] < 1);
-      if (input_dim >= input_shape[gsl::narrow_cast<size_t>(axis)]) input_dim = input_shape[gsl::narrow_cast<size_t>(axis)] - 1;
-      if (input_dim < 0) input_dim = 0;
-      input_mapping[gsl::narrow_cast<size_t>(dim)] = need_extrapolation ? (-input_size) : (input_dim * input_dim_factor[gsl::narrow_cast<size_t>(axis)]);
+      bool need_extrapolation = (extrapolation_enabled && (original_dim < 0 || original_dim > input_shape[narrow<size_t>(axis)] - 1));
+      int64_t input_dim = get_nearest_pixel(original_dim, scales[narrow<size_t>(axis)] < 1);
+      if (input_dim >= input_shape[narrow<size_t>(axis)]) input_dim = input_shape[narrow<size_t>(axis)] - 1;
+      if (input_dim < 0) input_dim = 0;
+      input_mapping[narrow<size_t>(dim)] = need_extrapolation ? (-input_size) : (input_dim * input_dim_factor[narrow<size_t>(axis)]);
     }
   }
@@ -148,11 +149,11 @@ static Status UpsampleNearestImpl(const T* input,
                                   const GetNearestPixelFunc& get_nearest_pixel) {
   int64_t n_dim = static_cast<int64_t>(input_shape.NumDimensions());
 
-  std::vector<int64_t> input_dim_counters(gsl::narrow_cast<size_t>(n_dim));
-  std::vector<int64_t> input_dim_factor(gsl::narrow_cast<size_t>(n_dim));
-  input_dim_factor[gsl::narrow_cast<size_t>(n_dim - 1)] = 1;  // initialize dimension factor
+  std::vector<int64_t> input_dim_counters(narrow<size_t>(n_dim));
+  std::vector<int64_t> input_dim_factor(narrow<size_t>(n_dim));
+  input_dim_factor[SafeInt<size_t>(n_dim) - 1] = 1;  // initialize dimension factor
   for (int64_t dim_idx = n_dim - 2; dim_idx >= 0; dim_idx--) {
-    input_dim_factor[gsl::narrow_cast<size_t>(dim_idx)] = input_dim_factor[gsl::narrow_cast<size_t>(dim_idx + 1)] * input_shape[gsl::narrow_cast<size_t>(dim_idx + 1)];
+    input_dim_factor[narrow<size_t>(dim_idx)] = input_dim_factor[SafeInt<size_t>(dim_idx) + 1] * input_shape[SafeInt<size_t>(dim_idx) + 1];
   }
 
   int64_t output_idx = 0;
@@ -162,14 +163,14 @@ static Status UpsampleNearestImpl(const T* input,
     std::vector<int64_t> input_mapping =
         UpsampleNearestSetupRank1InputMapping(input_shape[0],
                                               output_shape[0],
                                               scales[0],
-                                              roi[0], roi[gsl::narrow_cast<size_t>(n_dim + 0)],
+                                              roi[0], roi[narrow<size_t>(n_dim + 0)],
                                               extrapolation_enabled,
                                               get_original_coordinate,
                                               get_nearest_pixel);
 
     for (int64_t output_dim0_idx = 0; output_dim0_idx < output_shape[0]; output_dim0_idx++) {
-      int64_t input_dim0_idx = input_mapping[gsl::narrow_cast<size_t>(output_dim0_idx)];
-      output[gsl::narrow_cast<size_t>(output_dim0_idx)] = input_dim0_idx < 0 ? extrapolation_value : input[input_dim0_idx];
+      int64_t input_dim0_idx = input_mapping[narrow<size_t>(output_dim0_idx)];
+      output[narrow<size_t>(output_dim0_idx)] = input_dim0_idx < 0 ? extrapolation_value : input[input_dim0_idx];
     }
 
     return Status::OK();
@@ -184,9 +185,9 @@ static Status UpsampleNearestImpl(const T* input,
     const std::vector<int64_t>& input_mapping_1 = input_mappings[1];
 
     for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) {
-      int64_t input_idx_0 = input_mapping_0[gsl::narrow_cast<size_t>(output_dim0_inx)];
+      int64_t input_idx_0 = input_mapping_0[narrow<size_t>(output_dim0_inx)];
       for (int64_t output_dim1_inx = 0; output_dim1_inx < output_shape[1]; output_dim1_inx++) {
-        int64_t input_idx_1 = input_idx_0 + input_mapping_1[gsl::narrow_cast<size_t>(output_dim1_inx)];
+        int64_t input_idx_1 = input_idx_0 + input_mapping_1[narrow<size_t>(output_dim1_inx)];
         output[output_idx++] = (input_idx_1 < 0) ? extrapolation_value : input[input_idx_1];
       }
     }
@@ -199,11 +200,11 @@ static Status UpsampleNearestImpl(const T* input,
     const std::vector<int64_t>& input_mapping_2 = input_mappings[2];
 
     for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) {
-      int64_t input_idx_0 = input_mapping_0[gsl::narrow_cast<size_t>(output_dim0_inx)];
+      int64_t input_idx_0 = input_mapping_0[narrow<size_t>(output_dim0_inx)];
       for (int64_t output_dim1_inx = 0; output_dim1_inx < output_shape[1]; output_dim1_inx++) {
-        int64_t input_idx_1 = input_idx_0 + input_mapping_1[gsl::narrow_cast<size_t>(output_dim1_inx)];
+        int64_t input_idx_1 = input_idx_0 + input_mapping_1[narrow<size_t>(output_dim1_inx)];
         for (int64_t output_dim2_inx = 0; output_dim2_inx < output_shape[2]; output_dim2_inx++) {
-          int64_t input_idx_2 = input_idx_1 + input_mapping_2[gsl::narrow_cast<size_t>(output_dim2_inx)];
+          int64_t input_idx_2 = input_idx_1 + input_mapping_2[narrow<size_t>(output_dim2_inx)];
           output[output_idx++] = (input_idx_2 < 0) ? extrapolation_value : input[input_idx_2];
         }
       }
@@ -218,14 +219,14 @@ static Status UpsampleNearestImpl(const T* input,
     const std::vector<int64_t>& input_mapping_3 = input_mappings[3];
 
     for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) {
-      int64_t input_idx_0 = input_mapping_0[gsl::narrow_cast<size_t>(output_dim0_inx)];
+      int64_t input_idx_0 = input_mapping_0[narrow<size_t>(output_dim0_inx)];
       for (int64_t output_dim1_inx = 0; output_dim1_inx < output_shape[1]; output_dim1_inx++) {
-        int64_t input_idx_1 = input_idx_0 + input_mapping_1[gsl::narrow_cast<size_t>(output_dim1_inx)];
+        int64_t input_idx_1 = input_idx_0 + input_mapping_1[narrow<size_t>(output_dim1_inx)];
         for (int64_t output_dim2_inx = 0; output_dim2_inx < output_shape[2]; output_dim2_inx++) {
-          int64_t input_idx_2 = input_idx_1 + input_mapping_2[gsl::narrow_cast<size_t>(output_dim2_inx)];
+          int64_t input_idx_2 = input_idx_1 + input_mapping_2[narrow<size_t>(output_dim2_inx)];
           for (int64_t output_dim3_inx = 0; output_dim3_inx < output_shape[3]; output_dim3_inx++) {
-            int64_t input_idx_3 = input_idx_2 + input_mapping_3[gsl::narrow_cast<size_t>(output_dim3_inx)];
-            output[output_idx++] = (input_idx_3 < 0) ? static_cast<T>(extrapolation_value) : input[gsl::narrow_cast<size_t>(input_idx_3)];
+            int64_t input_idx_3 = input_idx_2 + input_mapping_3[narrow<size_t>(output_dim3_inx)];
+            output[output_idx++] = (input_idx_3 < 0) ? static_cast<T>(extrapolation_value) : input[narrow<size_t>(input_idx_3)];
           }
         }
       }
@@ -235,20 +236,20 @@ static Status UpsampleNearestImpl(const T* input,
   std::vector<int64_t> output_dim_counter(n_dim);
   for (int64_t dim_idx = 0; dim_idx < n_dim; dim_idx++) {
-    input_idx += input_mappings[gsl::narrow_cast<size_t>(dim_idx)][0 /* output_dim_counter[gsl::narrow_cast<size_t>(dim_idx)] */];
+    input_idx += input_mappings[narrow<size_t>(dim_idx)][0 /* output_dim_counter[narrow<size_t>(dim_idx)] */];
   }
 
   for (int64_t output_size = output_shape.Size(); output_idx < output_size; output_idx++) {
-    output[gsl::narrow_cast<size_t>(output_idx)] = (input_idx < 0) ? extrapolation_value : input[gsl::narrow_cast<size_t>(input_idx)];
+    output[narrow<size_t>(output_idx)] = (input_idx < 0) ? extrapolation_value : input[narrow<size_t>(input_idx)];
 
     for (int64_t dim_idx = n_dim - 1; dim_idx >= 0; dim_idx--) {
-      input_idx -= input_mappings[gsl::narrow_cast<size_t>(dim_idx)][gsl::narrow_cast<size_t>(output_dim_counter[gsl::narrow_cast<size_t>(dim_idx)])];
-      if (++output_dim_counter[gsl::narrow_cast<size_t>(dim_idx)] < output_shape[gsl::narrow_cast<size_t>(dim_idx)]) {
-        input_idx += input_mappings[gsl::narrow_cast<size_t>(dim_idx)][gsl::narrow_cast<size_t>(output_dim_counter[gsl::narrow_cast<size_t>(dim_idx)])];
+      input_idx -= input_mappings[narrow<size_t>(dim_idx)][narrow<size_t>(output_dim_counter[narrow<size_t>(dim_idx)])];
+      if (++output_dim_counter[narrow<size_t>(dim_idx)] < output_shape[narrow<size_t>(dim_idx)]) {
+        input_idx += input_mappings[narrow<size_t>(dim_idx)][narrow<size_t>(output_dim_counter[narrow<size_t>(dim_idx)])];
         break;
       }
-      output_dim_counter[gsl::narrow_cast<size_t>(dim_idx)] = 0;
-      input_idx += input_mappings[gsl::narrow_cast<size_t>(dim_idx)][0 /* output_dim_counter[gsl::narrow_cast<size_t>(dim_idx)] */];
+      output_dim_counter[narrow<size_t>(dim_idx)] = 0;
+      input_idx += input_mappings[narrow<size_t>(dim_idx)][0 /* output_dim_counter[narrow<size_t>(dim_idx)] */];
     }
   }
@@ -349,7 +350,7 @@ static Status UpsampleLinearImpl(const std::function
-      // output[gsl::narrow_cast<size_t>(i)] = 0;
+      // output[narrow<size_t>(i)] = 0;
       int64_t step = (1LL << n_dim) - 1;
       while (step >= 0) {
@@ -365,7 +366,7 @@ static Status UpsampleLinearImpl(const std::function
         >>= 1;
       }
 
-      // output[gsl::narrow_cast<size_t>(i)] += input[old_idx] * w;
+      // output[narrow<size_t>(i)] += input[old_idx] * w;
       apply(old_idx, i, w);
 
       step--;
@@ -390,7 +391,7 @@ static Status UpsampleLinear(const T* input,
   std::fill_n(output, output_shape.Size(), T{});
 
   auto apply = [&input, &output](size_t input_idx, size_t output_idx, float w) {
-    output[gsl::narrow_cast<size_t>(output_idx)] += input[gsl::narrow_cast<size_t>(input_idx)] * w;
+    output[narrow<size_t>(output_idx)] += input[narrow<size_t>(input_idx)] * w;
   };
 
   return UpsampleLinearImpl(apply, input_shape, output_shape, scales, is_resize, roi, get_original_coordinate);
@@ -471,16 +472,16 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
     const int32_t in_y1 = std::min(static_cast<int32_t>(in_y), input_height - 1);
     const int32_t in_y2 = std::min(in_y1 + 1, input_height - 1);
 
-    p.dy1[gsl::narrow_cast<size_t>(y)] = std::fabs(in_y - in_y1);
-    p.dy2[gsl::narrow_cast<size_t>(y)] = std::fabs(in_y - in_y2);
+    p.dy1[narrow<size_t>(y)] = std::fabs(in_y - in_y1);
+    p.dy2[narrow<size_t>(y)] = std::fabs(in_y - in_y2);
     if (in_y1 == in_y2) {
-      p.dy1[gsl::narrow_cast<size_t>(y)] = 0.5f;
-      p.dy2[gsl::narrow_cast<size_t>(y)] = 0.5f;
+      p.dy1[narrow<size_t>(y)] = 0.5f;
+      p.dy2[narrow<size_t>(y)] = 0.5f;
     }
 
-    p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] = input_width * in_y1;
-    p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] = input_width * in_y2;
+    p.input_width_mul_y1[narrow<size_t>(y)] = input_width * in_y1;
+    p.input_width_mul_y2[narrow<size_t>(y)] = input_width * in_y2;
   }
 
   const size_t width_rindex = is_nchw ? 0 : 1;
@@ -496,14 +497,14 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
     p.x_original.emplace_back(in_x);
     in_x = std::max(0.0f, std::min(in_x, static_cast<float>(input_width - 1)));
 
-    p.in_x1[gsl::narrow_cast<size_t>(x)] = std::min(static_cast<int32_t>(in_x), input_width - 1);
-    p.in_x2[gsl::narrow_cast<size_t>(x)] = std::min(p.in_x1[gsl::narrow_cast<size_t>(x)] + 1, input_width - 1);
+    p.in_x1[narrow<size_t>(x)] = std::min(static_cast<int32_t>(in_x), input_width - 1);
+    p.in_x2[narrow<size_t>(x)] = std::min(p.in_x1[narrow<size_t>(x)] + 1, input_width - 1);
 
-    p.dx1[gsl::narrow_cast<size_t>(x)] = std::fabs(in_x - p.in_x1[gsl::narrow_cast<size_t>(x)]);
-    p.dx2[gsl::narrow_cast<size_t>(x)] = std::fabs(in_x - p.in_x2[gsl::narrow_cast<size_t>(x)]);
-    if (p.in_x1[gsl::narrow_cast<size_t>(x)] == p.in_x2[gsl::narrow_cast<size_t>(x)]) {
-      p.dx1[gsl::narrow_cast<size_t>(x)] = 0.5f;
-      p.dx2[gsl::narrow_cast<size_t>(x)] = 0.5f;
+    p.dx1[narrow<size_t>(x)] = std::fabs(in_x - p.in_x1[narrow<size_t>(x)]);
+    p.dx2[narrow<size_t>(x)] = std::fabs(in_x - p.in_x2[narrow<size_t>(x)]);
+    if (p.in_x1[narrow<size_t>(x)] == p.in_x2[narrow<size_t>(x)]) {
+      p.dx1[narrow<size_t>(x)] = 0.5f;
+      p.dx2[narrow<size_t>(x)] = 0.5f;
     }
   }
@@ -578,16 +579,16 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
     const int32_t in_y1 = std::min(static_cast<int32_t>(in_y), input_height - 1);
     const int32_t in_y2 = std::min(in_y1 + 1, input_height - 1);
 
-    p.dy1_scale_10[gsl::narrow_cast<size_t>(y)] = std::abs(in_y_scale_10 - in_y1 * (1 << 10));
-    p.dy2_scale_10[gsl::narrow_cast<size_t>(y)] = std::abs(in_y_scale_10 - in_y2 * (1 << 10));
+    p.dy1_scale_10[narrow<size_t>(y)] = std::abs(in_y_scale_10 - in_y1 * (1 << 10));
+    p.dy2_scale_10[narrow<size_t>(y)] = std::abs(in_y_scale_10 - in_y2 * (1 << 10));
     if (in_y1 == in_y2) {
-      p.dy1_scale_10[gsl::narrow_cast<size_t>(y)] = static_cast<int32_t>(0.5f * (1 << 10));
-      p.dy2_scale_10[gsl::narrow_cast<size_t>(y)] = static_cast<int32_t>(0.5f * (1 << 10));
+      p.dy1_scale_10[narrow<size_t>(y)] = static_cast<int32_t>(0.5f * (1 << 10));
+      p.dy2_scale_10[narrow<size_t>(y)] = static_cast<int32_t>(0.5f * (1 << 10));
     }
 
-    p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] = input_width * in_y1;
-    p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] = input_width * in_y2;
+    p.input_width_mul_y1[narrow<size_t>(y)] = input_width * in_y1;
+    p.input_width_mul_y2[narrow<size_t>(y)] = input_width * in_y2;
   }
 
   const size_t width_rindex = is_nchw ? 0 : 1;
@@ -604,14 +605,14 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
     in_x = std::max(0.0f, std::min(in_x, static_cast<float>(input_width - 1)));
     int32_t in_x_scale_10 = static_cast<int32_t>(in_x * (1 << 10));
 
-    p.in_x1[gsl::narrow_cast<size_t>(x)] = std::min(static_cast<int32_t>(in_x), input_width - 1);
-    p.in_x2[gsl::narrow_cast<size_t>(x)] = std::min(p.in_x1[gsl::narrow_cast<size_t>(x)] + 1, input_width - 1);
+    p.in_x1[narrow<size_t>(x)] = std::min(static_cast<int32_t>(in_x), input_width - 1);
+    p.in_x2[narrow<size_t>(x)] = std::min(p.in_x1[narrow<size_t>(x)] + 1, input_width - 1);
 
-    p.dx1_scale_10[gsl::narrow_cast<size_t>(x)] = std::abs(in_x_scale_10 - p.in_x1[gsl::narrow_cast<size_t>(x)] * (1 << 10));
-    p.dx2_scale_10[gsl::narrow_cast<size_t>(x)] = std::abs(in_x_scale_10 - p.in_x2[gsl::narrow_cast<size_t>(x)] * (1 << 10));
-    if (p.in_x1[gsl::narrow_cast<size_t>(x)] == p.in_x2[gsl::narrow_cast<size_t>(x)]) {
-      p.dx1_scale_10[gsl::narrow_cast<size_t>(x)] = static_cast<int32_t>(0.5f * (1 << 10));
-      p.dx2_scale_10[gsl::narrow_cast<size_t>(x)] = static_cast<int32_t>(0.5f * (1 << 10));
+    p.dx1_scale_10[narrow<size_t>(x)] = std::abs(in_x_scale_10 - p.in_x1[narrow<size_t>(x)] * (1 << 10));
+    p.dx2_scale_10[narrow<size_t>(x)] = std::abs(in_x_scale_10 - p.in_x2[narrow<size_t>(x)] * (1 << 10));
+    if (p.in_x1[narrow<size_t>(x)] == p.in_x2[narrow<size_t>(x)]) {
+      p.dx1_scale_10[narrow<size_t>(x)] = static_cast<int32_t>(0.5f * (1 << 10));
+      p.dx2_scale_10[narrow<size_t>(x)] = static_cast<int32_t>(0.5f * (1 << 10));
     }
   }
@@ -654,9 +655,9 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
                                               const GetOriginalCoordinateFunc& get_original_coordinate) {
   TrilinearParams p;
 
-  p.z_original.reserve(gsl::narrow_cast<size_t>(output_depth));
-  p.y_original.reserve(gsl::narrow_cast<size_t>(output_height));
-  p.x_original.reserve(gsl::narrow_cast<size_t>(output_width));
+  p.z_original.reserve(narrow<size_t>(output_depth));
+  p.y_original.reserve(narrow<size_t>(output_height));
+  p.x_original.reserve(narrow<size_t>(output_width));
 
   // For each index in the output height and output width, cache its corresponding indices in the input
   // while multiplying it with the input stride for that dimension (cache because we don't have to re-compute
@@ -716,16 +717,16 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
     const int64_t in_z1 = std::min(static_cast<int64_t>(in_z), input_depth - 1);
     const int64_t in_z2 = std::min(in_z1 + 1, input_depth - 1);
 
-    p.dz1[gsl::narrow_cast<size_t>(z)] = std::fabs(in_z - in_z1);
-    p.dz2[gsl::narrow_cast<size_t>(z)] = std::fabs(in_z - in_z2);
+    p.dz1[narrow<size_t>(z)] = std::fabs(in_z - in_z1);
+    p.dz2[narrow<size_t>(z)] = std::fabs(in_z - in_z2);
     if (in_z1 == in_z2) {
-      p.dz1[gsl::narrow_cast<size_t>(z)] = 0.5f;
-      p.dz2[gsl::narrow_cast<size_t>(z)] = 0.5f;
+      p.dz1[narrow<size_t>(z)] = 0.5f;
+      p.dz2[narrow<size_t>(z)] = 0.5f;
     }
 
-    p.input_height_width_mul_z1[gsl::narrow_cast<size_t>(z)] = input_height * input_width * in_z1;
-    p.input_height_width_mul_z2[gsl::narrow_cast<size_t>(z)] = input_height * input_width * in_z2;
+    p.input_height_width_mul_z1[narrow<size_t>(z)] = input_height * input_width * in_z1;
+    p.input_height_width_mul_z2[narrow<size_t>(z)] = input_height * input_width * in_z2;
   }
 
   auto roi_y_start = roi.size() / 2 - 2;
@@ -741,16 +742,16 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
     const int64_t in_y1 = std::min(static_cast<int64_t>(in_y), input_height - 1);
     const int64_t in_y2 = std::min(in_y1 + 1, input_height - 1);
 
-    p.dy1[gsl::narrow_cast<size_t>(y)] = std::fabs(in_y - in_y1);
-    p.dy2[gsl::narrow_cast<size_t>(y)] = std::fabs(in_y - in_y2);
+    p.dy1[narrow<size_t>(y)] = std::fabs(in_y - in_y1);
+    p.dy2[narrow<size_t>(y)] = std::fabs(in_y - in_y2);
     if (in_y1 == in_y2) {
-      p.dy1[gsl::narrow_cast<size_t>(y)] = 0.5f;
-      p.dy2[gsl::narrow_cast<size_t>(y)] = 0.5f;
+      p.dy1[narrow<size_t>(y)] = 0.5f;
+      p.dy2[narrow<size_t>(y)] = 0.5f;
     }
 
-    p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] = input_width * in_y1;
-    p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] = input_width * in_y2;
+    p.input_width_mul_y1[narrow<size_t>(y)] = input_width * in_y1;
+    p.input_width_mul_y2[narrow<size_t>(y)] = input_width * in_y2;
   }
 
   auto roi_x_start = roi.size() / 2 - 1;
@@ -764,14 +765,14 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
     p.x_original.emplace_back(in_x);
     in_x = std::max(0.0f, std::min(in_x, static_cast<float>(input_width - 1)));
 
-    p.in_x1[gsl::narrow_cast<size_t>(x)] = std::min(static_cast<int64_t>(in_x), input_width - 1);
-    p.in_x2[gsl::narrow_cast<size_t>(x)] = std::min(p.in_x1[gsl::narrow_cast<size_t>(x)] + 1, input_width - 1);
+    p.in_x1[narrow<size_t>(x)] = std::min(static_cast<int64_t>(in_x), input_width - 1);
+    p.in_x2[narrow<size_t>(x)] = std::min(p.in_x1[narrow<size_t>(x)] + 1, input_width - 1);
 
-    p.dx1[gsl::narrow_cast<size_t>(x)] = std::fabs(in_x - p.in_x1[gsl::narrow_cast<size_t>(x)]);
-    p.dx2[gsl::narrow_cast<size_t>(x)] = std::fabs(in_x - p.in_x2[gsl::narrow_cast<size_t>(x)]);
-    if (p.in_x1[gsl::narrow_cast<size_t>(x)] == p.in_x2[gsl::narrow_cast<size_t>(x)]) {
-      p.dx1[gsl::narrow_cast<size_t>(x)] = 0.5f;
-      p.dx2[gsl::narrow_cast<size_t>(x)] = 0.5f;
+    p.dx1[narrow<size_t>(x)] = std::fabs(in_x - p.in_x1[narrow<size_t>(x)]);
+    p.dx2[narrow<size_t>(x)] = std::fabs(in_x - p.in_x2[narrow<size_t>(x)]);
+    if (p.in_x1[narrow<size_t>(x)] == p.in_x2[narrow<size_t>(x)]) {
+      p.dx1[narrow<size_t>(x)] = 0.5f;
+      p.dx2[narrow<size_t>(x)] = 0.5f;
     }
   }
@@ -820,35 +821,35 @@ void UpsampleTrilinear(int64_t batch_size,
          // when use_extrapolation is set and original index of x or y is out of the dim range
          // then use extrapolation_value as the output value.
          if (use_extrapolation &&
-              ((p.z_original[gsl::narrow_cast<size_t>(z)] < 0 || p.z_original[gsl::narrow_cast<size_t>(z)] > static_cast<float>(input_depth - 1)) ||
-               (p.y_original[gsl::narrow_cast<size_t>(y)] < 0 || p.y_original[gsl::narrow_cast<size_t>(y)] > static_cast<float>(input_height - 1)) ||
-               (p.x_original[gsl::narrow_cast<size_t>(x)] < 0 || p.x_original[gsl::narrow_cast<size_t>(x)] > static_cast<float>(input_width - 1)))) {
+              ((p.z_original[narrow<size_t>(z)] < 0 || p.z_original[narrow<size_t>(z)] > static_cast<float>(input_depth - 1)) ||
+               (p.y_original[narrow<size_t>(y)] < 0 || p.y_original[narrow<size_t>(y)] > static_cast<float>(input_height - 1)) ||
+               (p.x_original[narrow<size_t>(x)] < 0 || p.x_original[narrow<size_t>(x)] > static_cast<float>(input_width - 1)))) {
            Ydata[output_width * output_height * z + output_width * y + x] = static_cast<T>(extrapolation_value);
            continue;
          }
 
          // subscript ordering in the variable - (xyz)
-          T X111 = Xdata[p.input_height_width_mul_z1[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] + p.in_x1[gsl::narrow_cast<size_t>(x)]];
-          T X211 = Xdata[p.input_height_width_mul_z1[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] + p.in_x2[gsl::narrow_cast<size_t>(x)]];
-          T X121 = Xdata[p.input_height_width_mul_z1[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] + p.in_x1[gsl::narrow_cast<size_t>(x)]];
-          T X221 = Xdata[p.input_height_width_mul_z1[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] + p.in_x2[gsl::narrow_cast<size_t>(x)]];
+          T X111 = Xdata[p.input_height_width_mul_z1[narrow<size_t>(z)] + p.input_width_mul_y1[narrow<size_t>(y)] + p.in_x1[narrow<size_t>(x)]];
+          T X211 = Xdata[p.input_height_width_mul_z1[narrow<size_t>(z)] + p.input_width_mul_y1[narrow<size_t>(y)] + p.in_x2[narrow<size_t>(x)]];
+          T X121 = Xdata[p.input_height_width_mul_z1[narrow<size_t>(z)] + p.input_width_mul_y2[narrow<size_t>(y)] + p.in_x1[narrow<size_t>(x)]];
+          T X221 = Xdata[p.input_height_width_mul_z1[narrow<size_t>(z)] + p.input_width_mul_y2[narrow<size_t>(y)] + p.in_x2[narrow<size_t>(x)]];
 
-          T X112 = Xdata[p.input_height_width_mul_z2[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] + p.in_x1[gsl::narrow_cast<size_t>(x)]];
-          T X212 = Xdata[p.input_height_width_mul_z2[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y1[gsl::narrow_cast<size_t>(y)] + p.in_x2[gsl::narrow_cast<size_t>(x)]];
-          T X122 = Xdata[p.input_height_width_mul_z2[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] + p.in_x1[gsl::narrow_cast<size_t>(x)]];
-          T X222 = Xdata[p.input_height_width_mul_z2[gsl::narrow_cast<size_t>(z)] + p.input_width_mul_y2[gsl::narrow_cast<size_t>(y)] + p.in_x2[gsl::narrow_cast<size_t>(x)]];
+          T X112 = Xdata[p.input_height_width_mul_z2[narrow<size_t>(z)] + p.input_width_mul_y1[narrow<size_t>(y)] + p.in_x1[narrow<size_t>(x)]];
+          T X212 = Xdata[p.input_height_width_mul_z2[narrow<size_t>(z)] + p.input_width_mul_y1[narrow<size_t>(y)] + p.in_x2[narrow<size_t>(x)]];
+          T X122 = Xdata[p.input_height_width_mul_z2[narrow<size_t>(z)] + p.input_width_mul_y2[narrow<size_t>(y)] + p.in_x1[narrow<size_t>(x)]];
+          T X222 = Xdata[p.input_height_width_mul_z2[narrow<size_t>(z)] + p.input_width_mul_y2[narrow<size_t>(y)] + p.in_x2[narrow<size_t>(x)]];
 
           Ydata[output_width * output_height * z + output_width * y + x] =
-              static_cast<T>(p.dx2[gsl::narrow_cast<size_t>(x)] * p.dy2[gsl::narrow_cast<size_t>(y)] * p.dz2[gsl::narrow_cast<size_t>(z)] * X111 +
-                             p.dx1[gsl::narrow_cast<size_t>(x)] * p.dy2[gsl::narrow_cast<size_t>(y)] * p.dz2[gsl::narrow_cast<size_t>(z)] * X211 +
-                             p.dx2[gsl::narrow_cast<size_t>(x)] * p.dy1[gsl::narrow_cast<size_t>(y)] * p.dz2[gsl::narrow_cast<size_t>(z)] * X121 +
-                             p.dx1[gsl::narrow_cast<size_t>(x)] * p.dy1[gsl::narrow_cast<size_t>(y)] * p.dz2[gsl::narrow_cast<size_t>(z)] * X221 +
+              static_cast<T>(p.dx2[narrow<size_t>(x)] * p.dy2[narrow<size_t>(y)] * p.dz2[narrow<size_t>(z)] * X111 +
+                             p.dx1[narrow<size_t>(x)] * p.dy2[narrow<size_t>(y)] * p.dz2[narrow<size_t>(z)] * X211 +
+                             p.dx2[narrow<size_t>(x)] * p.dy1[narrow<size_t>(y)] * p.dz2[narrow<size_t>(z)] * X121 +
+                             p.dx1[narrow<size_t>(x)] * p.dy1[narrow<size_t>(y)] * p.dz2[narrow<size_t>(z)] * X221 +
 
-                             p.dx2[gsl::narrow_cast<size_t>(x)] * p.dy2[gsl::narrow_cast<size_t>(y)] * p.dz1[gsl::narrow_cast<size_t>(z)] * X112 +
-                             p.dx1[gsl::narrow_cast<size_t>(x)] * p.dy2[gsl::narrow_cast<size_t>(y)] * p.dz1[gsl::narrow_cast<size_t>(z)] * X212 +
-                             p.dx2[gsl::narrow_cast<size_t>(x)] * p.dy1[gsl::narrow_cast<size_t>(y)] * p.dz1[gsl::narrow_cast<size_t>(z)] * X122 +
-                             p.dx1[gsl::narrow_cast<size_t>(x)] * p.dy1[gsl::narrow_cast<size_t>(y)] * p.dz1[gsl::narrow_cast<size_t>(z)] * X222);
+                             p.dx2[narrow<size_t>(x)] * p.dy2[narrow<size_t>(y)] * p.dz1[narrow<size_t>(z)] * X112 +
+                             p.dx1[narrow<size_t>(x)] * p.dy2[narrow<size_t>(y)] * p.dz1[narrow<size_t>(z)] * X212 +
+                             p.dx2[narrow<size_t>(x)] * p.dy1[narrow<size_t>(y)] * p.dz1[narrow<size_t>(z)] * X122 +
+                             p.dx1[narrow<size_t>(x)] * p.dy1[narrow<size_t>(y)] * p.dz1[narrow<size_t>(z)] * X222);
         }
       }
     }
@@ -905,7 +906,7 @@ float CubicInterpolation1D(const T* Xdata,
   float result = 0;
   for (int i = 0, j = -1; i < static_cast<int>(CubicModeGridLength); i++, j++) {
     auto orig_data = GetDataForCoordinate(Xdata, x + j, y, input_height, input_width);
-    result += coeff_array[gsl::narrow_cast<size_t>(i)] / coeff_sum * orig_data;
+    result += coeff_array[narrow<size_t>(i)] / coeff_sum * orig_data;
   }
 
   cache[grid_start_pos] = result;
@@ -933,10 +934,10 @@ void ResizeBiCubic(int64_t batch_size,
                    T* Ydata,
                    const GetOriginalCoordinateFunc& get_original_coordinate) {
   std::vector<float> y_original;
-  y_original.reserve(gsl::narrow_cast<size_t>(output_height));
+  y_original.reserve(narrow<size_t>(output_height));
 
   std::vector<float> x_original;
-  x_original.reserve(gsl::narrow_cast<size_t>(output_width));
+  x_original.reserve(narrow<size_t>(output_width));
 
   std::unordered_map<float, std::array<float, CubicModeGridLength>> cubic_coeffs;
   std::unordered_map<float, std::unordered_map<int64_t, float>> coeff_to_1Dinterpolation_map;
@@ -953,7 +954,7 @@ void ResizeBiCubic(int64_t batch_size,
                                            static_cast<float>(input_height), roi[roi_y_start], roi[roi_y_end]);
     y_original.emplace_back(in_y);
 
-    auto s = y_original[gsl::narrow_cast<size_t>(y)] - std::floor(y_original[gsl::narrow_cast<size_t>(y)]);
+    auto s = y_original[narrow<size_t>(y)] - std::floor(y_original[narrow<size_t>(y)]);
     if (cubic_coeffs.find(s) == cubic_coeffs.end()) {
       cubic_coeffs[s] = GetCubicCoeffs(s, cubic_coeff_a);
       coeff_to_1Dinterpolation_map[s] = {};
     }
   }
 
@@ -969,7 +970,7 @@ void ResizeBiCubic(int64_t batch_size,
                                            static_cast<float>(input_width), roi[roi_x_start], roi[roi_x_end]);
     x_original.emplace_back(in_x);
 
-    auto s = x_original[gsl::narrow_cast<size_t>(x)] - std::floor(x_original[gsl::narrow_cast<size_t>(x)]);
+    auto s = x_original[narrow<size_t>(x)] - std::floor(x_original[narrow<size_t>(x)]);
     if (cubic_coeffs.find(s) == cubic_coeffs.end()) {
       cubic_coeffs[s] = GetCubicCoeffs(s, cubic_coeff_a);
       coeff_to_1Dinterpolation_map[s] = {};
@@ -985,7 +986,7 @@ void ResizeBiCubic(int64_t batch_size,
   for (int64_t n = 0; n < batch_size; n++) {
     for (int64_t c = 0; c < num_channels; c++) {
       for (int64_t y = 0; y < output_height; ++y) {
-        auto in_y = y_original[gsl::narrow_cast<size_t>(y)];
+        auto in_y = y_original[narrow<size_t>(y)];
 
         // when use_extrapolation is set and original index is out of the dim range
         // then use extrapolation_value as the output value.
@@ -1006,13 +1007,13 @@ void ResizeBiCubic(int64_t batch_size,
          y_coeff_sum = 0;
          auto& orig_y_coeffs = cubic_coeffs[in_y - y_int];
          for (int64_t i = 0, y_val = y_int - 1; y_val <= y_int + 2; y_val++, i++) {
-            y_coeff_holder[gsl::narrow_cast<size_t>(i)] = (y_val < 0 || y_val >= static_cast<int64_t>(input_height)) ? 0.0f : orig_y_coeffs[gsl::narrow_cast<size_t>(i)];
-            y_coeff_sum += y_coeff_holder[gsl::narrow_cast<size_t>(i)];
+            y_coeff_holder[narrow<size_t>(i)] = (y_val < 0 || y_val >= static_cast<int64_t>(input_height)) ? 0.0f : orig_y_coeffs[narrow<size_t>(i)];
+            y_coeff_sum += y_coeff_holder[narrow<size_t>(i)];
          }
        }
 
        for (int64_t x = 0; x < output_width; ++x) {
-          auto in_x = x_original[gsl::narrow_cast<size_t>(x)];
+          auto in_x = x_original[narrow<size_t>(x)];
 
          // when use_extrapolation is set and original index is out of the dim range
          // then use extrapolation_value as the output value.
@@ -1032,8 +1033,8 @@ void ResizeBiCubic(int64_t batch_size,
            x_coeff_sum = 0;
            auto& orig_x_coeff = cubic_coeffs[s_x];
            for (int64_t i = 0, x_val = x_int - 1; x_val <= x_int + 2; x_val++, i++) {
-              x_coeff_holder[gsl::narrow_cast<size_t>(i)] = (x_val < 0 || x_val >= static_cast<int64_t>(input_width)) ? 0.0f : orig_x_coeff[gsl::narrow_cast<size_t>(i)];
-              x_coeff_sum += x_coeff_holder[gsl::narrow_cast<size_t>(i)];
+              x_coeff_holder[narrow<size_t>(i)] = (x_val < 0 || x_val >= static_cast<int64_t>(input_width)) ? 0.0f : orig_x_coeff[narrow<size_t>(i)];
+              x_coeff_sum += x_coeff_holder[narrow<size_t>(i)];
            }
          }
 
@@ -1045,7 +1046,7 @@ void ResizeBiCubic(int64_t batch_size,
            auto x_interpolation_result = CubicInterpolation1D(Xdata, x_int, y_val,
                                                               input_height, input_width, coeff_x, x_coeff_sum,
                                                               interpolation_result_cache);
-            result += x_interpolation_result * coeff_y[gsl::narrow_cast<size_t>(i)] / y_coeff_sum;
+            result += x_interpolation_result * coeff_y[narrow<size_t>(i)] / y_coeff_sum;
          }
 
          Ydata[y * output_width + x] = static_cast<T>(result);
@@ -1092,7 +1093,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
 
   bool no_scale = true;
   for (std::size_t i = 0, end = output_dims.size(); i < end; i++) {
-    if (no_scale && output_dims[gsl::narrow_cast<size_t>(i)] != dims[gsl::narrow_cast<size_t>(i)]) no_scale = false;
+    if (no_scale && output_dims[narrow<size_t>(i)] != dims[narrow<size_t>(i)]) no_scale = false;
  }
 
  if (no_scale) {
@@ -1300,7 +1301,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
    size_t input_rank = input_dims.size();
    roi_array.resize(input_rank * 2);
    for (size_t i = 0; i < input_rank; ++i) {
-      roi_array[gsl::narrow_cast<size_t>(i)] = 0;
+      roi_array[narrow<size_t>(i)] = 0;
      roi_array[i + input_rank] = 1;
    }
  }
@@ -1336,7 +1337,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
    ORT_ENFORCE(sizes != nullptr && sizes->Shape().Size() != 0,
                "Either scales or sizes MUST be provided as input.");
 
    // When sizes input is available directly populate it into the output_dims array.
-    memcpy(output_dims.data(), sizes->template Data<int64_t>(), gsl::narrow_cast<size_t>(sizes->Shape().Size()) * sizeof(int64_t));
+    memcpy(output_dims.data(), sizes->template Data<int64_t>(), SafeInt<size_t>(sizes->Shape().Size()) * sizeof(int64_t));
 
    ORT_ENFORCE(X->Shape().GetDims().size() == output_dims.size(),
                "Resize: input tensor's rank does not match the output tensor's rank.");
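For context on the helpers this patch swaps in: gsl::narrow_cast is an unchecked static_cast that only documents intent; gsl::narrow (and onnxruntime::narrow, used throughout the diff) performs the same cast but fails loudly when the value does not survive the conversion; SafeInt additionally checks the arithmetic done on the value (as in the `SafeInt<size_t>(n_dim) - 1` index computations above), not just the cast itself. Below is a minimal sketch of the checked-cast semantics; `checked_narrow` is a hand-rolled stand-in for illustration, not the actual GSL or onnxruntime implementation (gsl::narrow throws gsl::narrowing_error; std::runtime_error is used here only to keep the sketch dependency-free):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Stand-in for gsl::narrow / onnxruntime::narrow (assumption: simplified):
// perform the cast, but throw if the round trip changes the value or
// flips the sign, instead of silently truncating like static_cast.
template <typename To, typename From>
To checked_narrow(From value) {
  To result = static_cast<To>(value);
  if (static_cast<From>(result) != value ||
      ((result < To{}) != (value < From{}))) {  // sign changed by the cast
    throw std::runtime_error("narrowing conversion lost information");
  }
  return result;
}

int main() {
  int64_t fits = 42;
  int64_t negative = -1;

  // Like gsl::narrow_cast<size_t>: unchecked, documents intent only.
  std::size_t unchecked = static_cast<std::size_t>(negative);
  std::cout << "unchecked cast of -1 -> " << unchecked << "\n";  // wraps to a huge value

  // Like narrow<size_t>: succeeds for representable values...
  std::cout << "checked cast of 42 -> " << checked_narrow<std::size_t>(fits) << "\n";

  // ...and fails loudly instead of wrapping for unrepresentable ones.
  try {
    checked_narrow<std::size_t>(negative);
  } catch (const std::runtime_error& e) {
    std::cout << "checked cast of -1 -> exception: " << e.what() << "\n";
  }
  return 0;
}
```

The SafeInt uses in the diff go one step further: an expression such as `SafeInt<size_t>(dim_idx) + 1` validates the addition as well as the conversion, which is why the index arithmetic (rather than plain `narrow<size_t>(dim_idx + 1)`) was rewritten that way.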