mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20441 This op is fairly complex and the fact that it isn't formatted correctly makes things that much harder to reason about. Clean it up. Reviewed By: dreiss Differential Revision: D15220006 fbshipit-source-id: 30632d8bdbf15f96e73d8b6c96c5f29c052e6e7c
38 lines
1.1 KiB
C++
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/image/image_input_op.h"
namespace caffe2 {
|
|
|
|
template <>
|
|
bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
|
|
const std::vector<std::int64_t>& dims,
|
|
const c10::Device& type) {
|
|
// GPU transform kernel allows explicitly setting output type
|
|
if (output_type_ == TensorProto_DataType_FLOAT) {
|
|
auto* image_output =
|
|
OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(type));
|
|
TransformOnGPU<uint8_t, float, CUDAContext>(
|
|
prefetched_image_on_device_,
|
|
image_output,
|
|
mean_gpu_,
|
|
std_gpu_,
|
|
&context_);
|
|
} else if (output_type_ == TensorProto_DataType_FLOAT16) {
|
|
auto* image_output =
|
|
OperatorBase::OutputTensor(0, dims, at::dtype<at::Half>().device(type));
|
|
TransformOnGPU<uint8_t, at::Half, CUDAContext>(
|
|
prefetched_image_on_device_,
|
|
image_output,
|
|
mean_gpu_,
|
|
std_gpu_,
|
|
&context_);
|
|
} else {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);
|
|
|
|
} // namespace caffe2
|