pytorch/caffe2/image/image_input_op_gpu.cc
Ashwin Bharambe 36d3398aa5 Clang-format ImageInputOp (#20441)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20441

This op is fairly complex and the fact that it isn't formatted
correctly makes things that much harder to reason about. Clean it up.

Reviewed By: dreiss

Differential Revision: D15220006

fbshipit-source-id: 30632d8bdbf15f96e73d8b6c96c5f29c052e6e7c
2019-05-16 23:00:09 -07:00

38 lines
1.1 KiB
C++

#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/image/image_input_op.h"
namespace caffe2 {
template <>
bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
    const std::vector<std::int64_t>& dims,
    const c10::Device& type) {
  // The GPU transform kernel lets us pick the output element type
  // explicitly; dispatch on the configured output_type_ and report
  // failure for anything we do not support.
  switch (output_type_) {
    case TensorProto_DataType_FLOAT: {
      auto* image_output =
          OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(type));
      TransformOnGPU<uint8_t, float, CUDAContext>(
          prefetched_image_on_device_,
          image_output,
          mean_gpu_,
          std_gpu_,
          &context_);
      return true;
    }
    case TensorProto_DataType_FLOAT16: {
      auto* image_output = OperatorBase::OutputTensor(
          0, dims, at::dtype<at::Half>().device(type));
      TransformOnGPU<uint8_t, at::Half, CUDAContext>(
          prefetched_image_on_device_,
          image_output,
          mean_gpu_,
          std_gpu_,
          &context_);
      return true;
    }
    default:
      // Unsupported output type requested; caller handles the failure.
      return false;
  }
}
// Register the CUDA implementation of the ImageInput operator.
REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);
} // namespace caffe2