diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 385f0f8864..91603549f7 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -158,47 +158,68 @@ void upsampleBilinear( float height_scale, float width_scale, const T* Xdata, - T* Ydata) { + T* Ydata, + AllocatorPtr& alloc) { int64_t output_width = static_cast(input_width * width_scale); int64_t output_height = static_cast(input_height * height_scale); + size_t inx_buffer_size = 2 * sizeof(int64_t) * (output_height + output_width); + size_t scale_buffer_size = 2 * sizeof(float_t) * (output_height + output_width); + auto inx_scale_data_buffer = alloc->Alloc(inx_buffer_size + scale_buffer_size); + BufferUniquePtr inx_scale_data_buffer_holder(inx_scale_data_buffer, BufferDeleter(alloc)); + int64_t* inx_data = static_cast(inx_scale_data_buffer_holder.get()); + int64_t* input_width_mul_y1 = inx_data; + int64_t* input_width_mul_y2 = inx_data + output_height; + int64_t* in_x1 = inx_data + 2 * output_height; + int64_t* in_x2 = inx_data + 2 * output_height + output_width; + + float* scale_data = reinterpret_cast( in_x2 + output_width ); + float* dy1 = scale_data; + float* dy2 = scale_data + output_height; + float* dx1 = scale_data + 2 * output_height; + float* dx2 = scale_data + 2 * output_height + output_width; + + for (int64_t y = 0; y < output_height; ++y) { + float in_y = std::min(y / height_scale, static_cast(input_height - 1)); + const int64_t in_y1 = std::min(static_cast(in_y), input_height - 1); + const int64_t in_y2 = std::min(in_y1 + 1, input_height - 1); + dy1[y] = fabs(in_y - in_y1); + dy2[y] = fabs(in_y - in_y2); + if (in_y1 == in_y2) { + dy1[y] = 0.5f; + dy2[y] = 0.5f; + } + + input_width_mul_y1[y] = input_width * in_y1; + input_width_mul_y2[y] = input_width * in_y2; + } + + for (int64_t x = 0; x < output_width; ++x) { + float in_x = std::min(x / width_scale, static_cast(input_width - 1)); + in_x1[x] = std::min(static_cast(in_x), input_width - 1); + in_x2[x] = std::min(in_x1[x] + 1, input_width - 1); + + dx1[x] = std::abs(in_x - in_x1[x]); + dx2[x] = std::abs(in_x - in_x2[x]); + if (in_x1[x] == in_x2[x]) { + dx1[x] = 0.5f; + dx2[x] = 0.5f; + } + } + for (int64_t n = 0; n < batch_size; ++n) { for (int64_t c = 0; c < num_channels; ++c) { for (int64_t y = 0; y < output_height; ++y) { - float in_y = std::min(y / height_scale, static_cast(input_height - 1)); - const int64_t in_y1 = std::min(static_cast(in_y), input_height - 1); - const int64_t in_y2 = std::min(in_y1 + 1, input_height - 1); - float dy1 = fabs(in_y - in_y1); - float dy2 = fabs(in_y - in_y2); - if (in_y1 == in_y2) { - dy1 = 0.5f; - dy2 = 0.5f; - } - - const int64_t input_width_mul_y1 = input_width * in_y1; - const int64_t input_width_mul_y2 = input_width * in_y2; - for (int64_t x = 0; x < output_width; ++x) { - float in_x = std::min(x / width_scale, static_cast(input_width - 1)); - const int64_t in_x1 = std::min(static_cast(in_x), input_width - 1); - const int64_t in_x2 = std::min(in_x1 + 1, input_width - 1); + T X11 = Xdata[input_width_mul_y1[y] + in_x1[x]]; + T X21 = Xdata[input_width_mul_y1[y] + in_x2[x]]; + T X12 = Xdata[input_width_mul_y2[y] + in_x1[x]]; + T X22 = Xdata[input_width_mul_y2[y] + in_x2[x]]; - float dx1 = std::abs(in_x - in_x1); - float dx2 = std::abs(in_x - in_x2); - if (in_x1 == in_x2) { - dx1 = 0.5f; - dx2 = 0.5f; - } - - T X11 = Xdata[input_width_mul_y1 + in_x1]; - T X21 = Xdata[input_width_mul_y1 + in_x2]; - T X12 = Xdata[input_width_mul_y2 + in_x1]; - T X22 = Xdata[input_width_mul_y2 + in_x2]; - - Ydata[output_width * y + x] = static_cast(dx2 * dy2 * X11 + - dx1 * dy2 * X21 + - dx2 * dy1 * X12 + - dx1 * dy1 * X22); + Ydata[output_width * y + x] = static_cast(dx2[x] * dy2[y] * X11 + + dx1[x] * dy2[y] * X21 + + dx2[x] * dy1[y] * X12 + + dx1[x] * dy1[y] * X22); } } Xdata += input_height * input_width; @@ -235,8 +256,10 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vectorGetTempSpaceAllocator(&alloc)); upsampleBilinear(batch_size, num_channels, input_height, input_width, - scales[2], scales[3], X->template Data(), Y->template MutableData()); + scales[2], scales[3], X->template Data(), Y->template MutableData(), alloc); return Status::OK(); } default: