diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 6cbacc9b4d..c1b80f63ed 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -66,30 +66,121 @@ Status UpsampleNearest(const T* input, return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr"); if (input_shape.NumDimensions() != output_shape.NumDimensions()) return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch"); - auto n_dim = input_shape.NumDimensions(); - if (scales.size() == 4 && scales[0] == 1 && scales[1] == 1 && scales[2] == 2 && scales[3] == 2) { - UpsampleNearest2x(input_shape[0], input_shape[1], input_shape[2], input_shape[3], input, output); - } else { - for (size_t i = 0, size = output_shape.Size(); i < size; i++) { - size_t old_idx = 0; - size_t cur_idx = i; - - int64_t base = 1; - for (auto j = static_cast(n_dim - 1); j >= 0; j--) { - auto tmp = cur_idx % output_shape[j]; - - if (scales[j] < 1) { //downsample - old_idx += (std::min(static_cast(std::ceil(tmp / scales[j])), input_shape[j] - 1)) * base; - } else { //upsample - old_idx += (std::min(static_cast(tmp / scales[j]), input_shape[j] - 1)) * base; - } - base *= input_shape[j]; - cur_idx /= output_shape[j]; - } - - output[i] = input[old_idx]; - } + if (input_shape.NumDimensions() == 0) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Upsample: input shape needs to be at least a single dimension."); } + + int64_t n_dim = static_cast(input_shape.NumDimensions()); + + std::vector input_dim_counters(n_dim); + std::vector input_dim_factor(n_dim); + input_dim_factor[n_dim - 1] = 1; // initialize dimension factor + for (int64_t dim_idx = n_dim - 2; dim_idx >= 0; dim_idx--) { + input_dim_factor[dim_idx] = input_dim_factor[dim_idx + 1] * input_shape[dim_idx + 1]; + } + + int64_t output_idx = 0; + int64_t input_idx = 0; + +#define OneDemensionProcessor(dim_inx) \ + int64_t input_dim##dim_inx##_inx = \ + static_cast(scales[dim_inx] < 1 ? std::ceil(output_dim##dim_inx##_inx / scales[dim_inx]) : output_dim##dim_inx##_inx / scales[dim_inx]); \ + if (input_dim##dim_inx##_inx > input_shape[dim_inx] - 1) input_dim##dim_inx##_inx = input_shape[dim_inx] - 1; \ + if (input_dim##dim_inx##_inx != input_dim_counters[dim_inx]) { \ + input_idx += (input_dim##dim_inx##_inx - input_dim_counters[dim_inx]) * input_dim_factor[dim_inx]; \ + input_dim_counters[dim_inx] = input_dim##dim_inx##_inx; \ + } + + if (n_dim == 1) { + for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) { + OneDemensionProcessor(0); + output[output_idx++] = input[input_idx]; + } + return Status::OK(); + } + + if (n_dim == 2) { + for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) { + OneDemensionProcessor(0); + for (int64_t output_dim1_inx = 0; output_dim1_inx < output_shape[1]; output_dim1_inx++) { + OneDemensionProcessor(1); + output[output_idx++] = input[input_idx]; + } + } + return Status::OK(); + } + + if (n_dim == 3) { + for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) { + OneDemensionProcessor(0); + for (int64_t output_dim1_inx = 0; output_dim1_inx < output_shape[1]; output_dim1_inx++) { + OneDemensionProcessor(1); + for (int64_t output_dim2_inx = 0; output_dim2_inx < output_shape[2]; output_dim2_inx++) { + OneDemensionProcessor(2); + output[output_idx++] = input[input_idx]; + } + } + } + return Status::OK(); + } + + if (n_dim == 4) { + if (scales[0] == 1 && scales[1] == 1 && scales[2] == 2 && scales[3] == 2) { + UpsampleNearest2x(input_shape[0], input_shape[1], input_shape[2], input_shape[3], input, output); + return Status::OK(); + } + for (int64_t output_dim0_inx = 0; output_dim0_inx < output_shape[0]; output_dim0_inx++) { + OneDemensionProcessor(0); + for (int64_t output_dim1_inx = 0; output_dim1_inx < output_shape[1]; output_dim1_inx++) { + OneDemensionProcessor(1); + for (int64_t output_dim2_inx = 0; output_dim2_inx < output_shape[2]; output_dim2_inx++) { + OneDemensionProcessor(2); + for (int64_t output_dim3_inx = 0; output_dim3_inx < output_shape[3]; output_dim3_inx++) { + OneDemensionProcessor(3); + output[output_idx++] = input[input_idx]; + } + } + } + } + return Status::OK(); + } + +#undef OneDemensionProcessor + + std::vector output_dim_counter(n_dim); + output_dim_counter[n_dim - 1] = -1; // initialize dimension counter + + for (; output_idx < output_shape.Size(); output_idx++) { + for (int64_t dim_idx = n_dim - 1; dim_idx >= 0; dim_idx--) { + if (++output_dim_counter[dim_idx] < output_shape[dim_idx]) { + int64_t current_input_dim_counter = 0; + if (scales[dim_idx] < 1) //downsample + { + current_input_dim_counter = static_cast(std::ceil(output_dim_counter[dim_idx] / scales[dim_idx])); + } else //upsample + { + current_input_dim_counter = static_cast(output_dim_counter[dim_idx] / scales[dim_idx]); + } + + if (current_input_dim_counter >= input_shape[dim_idx] - 1) + current_input_dim_counter = input_shape[dim_idx] - 1; + + if (current_input_dim_counter != input_dim_counters[dim_idx]) { + input_idx += (current_input_dim_counter - input_dim_counters[dim_idx]) * input_dim_factor[dim_idx]; + input_dim_counters[dim_idx] = current_input_dim_counter; + } + break; + } else { + output_dim_counter[dim_idx] = 0; + input_idx += (0 - input_dim_counters[dim_idx]) * input_dim_factor[dim_idx]; + input_dim_counters[dim_idx] = 0; + } + } + + output[output_idx] = input[input_idx]; + } + return Status::OK(); }