pytorch/caffe2/image/transform_gpu.cu
Jerry Zhang 0f59dcb317 Remove partially initialized Tensor + CopyFrom (#13629)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13629

Previously, we had a Tensor with an initialized storage (and therefore a known device_type), and
then we would call CopyFrom on it to initialize the sizes and data.

We want to eliminate partially initialized Tensors. To do so, we replace the pattern of calling CopyFrom on a partially initialized Tensor with either an undefined Tensor plus an initialization API (cases 1 and 3), or a single step that combines all of the initialization (case 2).

1. member variable initialization + CopyFrom
Previously, we had a tensor that was initialized with a device_type and then used CopyFrom to populate its content. Now we remove the partial initialization by making the original member variable an undefined Tensor and using ReinitializeFrom to copy from another Tensor.

2. Output + CopyFrom
Previously, we first got a tensor with a device_type and then used CopyFrom to copy from another Tensor.
We changed this to combine these two operations into OperatorBase::OutputTensor.

3. Output + custom functions
Example can be found in TransformGPU function.
In this case, we move the part that initializes the tensor out of the function and perform it explicitly at the call site, so that we can reuse the Output functions to make a fully initialized Tensor.

Note that, to keep the original semantics, both of the APIs have a caching effect based on device_type: we only create a new Tensor object when the device_type does not match or the Tensor is undefined; otherwise, we reuse the original Tensor object.

Reviewed By: dzhulgakov

Differential Revision: D12848855

fbshipit-source-id: 37bb4ddc1698ebea533b73006eeb1218faa8ddf8
2018-11-07 11:31:03 -08:00

84 lines
2 KiB
Text

#include "caffe2/core/context_gpu.h"
#include "caffe2/image/transform_gpu.h"
#include "caffe2/utils/conversions.h"
/**
*
* Copyright (c) 2016, NVIDIA CORPORATION, All rights reserved
* Distributed under 2-clause BSD license; see accompanying LICENSE file
*
**/
namespace caffe2 {
namespace {
// Converts a batch of images from NHWC layout to NCHW layout while applying a
// per-channel affine normalization: out = (float(in) - mean[c]) * std[c].
//
// Note that std[c] is used as a multiplier, not a divisor.
// NOTE(review): presumably the caller passes the reciprocal of the standard
// deviation here -- confirm against the code that fills `std`.
//
// Launch layout expected by the indexing below:
//   - blockIdx.x selects the image, so the grid needs one block per image.
//   - threadIdx.y / blockDim.y stride over rows (h).
//   - threadIdx.x / blockDim.x stride over columns (w).
// No shared memory is used.
//
// Template parameters:
//   In  - element type of the input (this file instantiates it with uint8_t).
//   Out - element type of the output (this file instantiates it with float
//         and at::Half).
//
// Parameters:
//   N    - batch size; not read by the kernel body (the image index comes
//          from blockIdx.x), kept for interface compatibility.
//   C, H, W - channels, height and width of every image.
//   mean, std - per-channel values, indexed by channel in [0, C).
//   in   - input data, indexed as [n][h][w][c].
//   out  - output data, indexed as [n][c][h][w].
template <typename In, typename Out>
__global__ void transform_kernel(
    const int N,
    const int C,
    const int H,
    const int W,
    const float* mean,
    const float* std,
    const In* in,
    Out* out) {
  const int n = blockIdx.x;

  // The per-image base offset is computed in size_t: n * C * H * W can exceed
  // the range of a 32-bit int for large batches even when one image fits.
  const size_t image_stride = static_cast<size_t>(C) * H * W;
  const size_t image_offset = static_cast<size_t>(n) * image_stride;

  // pointers to data for this image
  const In* input_ptr = in + image_offset;
  Out* output_ptr = out + image_offset;

  // either read or write uncoalesced - try reading
  for (int c = 0; c < C; ++c) {
    // Loop-invariant for the whole channel plane; load once per channel.
    const float channel_mean = mean[c];
    const float channel_scale = std[c];
    for (int h = threadIdx.y; h < H; h += blockDim.y) {
      for (int w = threadIdx.x; w < W; w += blockDim.x) {
        const int in_idx = c + C * w + C * W * h; // HWC
        const int out_idx = c * H * W + h * W + w; // CHW
        output_ptr[out_idx] = convert::To<float, Out>(
            (convert::To<In, float>(input_ptr[in_idx]) - channel_mean) *
            channel_scale);
      }
    }
  }
}
}
/**
 * Launches transform_kernel on the context's CUDA stream to convert X into Y.
 *
 * X's dimensions are read as (N, H, W, C) from dims 0, 1, 2 and 3; the kernel
 * writes Y in (N, C, H, W) order, applying (x - mean[c]) * std[c] per channel.
 * Y is not resized here: only mutable_data<T_OUT>() is requested, so the
 * caller has to provide Y with its final shape.
 *
 * The launch uses one block per image (grid = N) with 16x16 threads per block.
 * It is asynchronous with respect to the host; no synchronization is done.
 *
 * @param X       input tensor with elements of type T_IN.
 * @param Y       output tensor with elements of type T_OUT.
 * @param mean    float tensor, read per channel by the kernel.
 * @param std     float tensor, read per channel by the kernel (multiplied).
 * @param context provides the CUDA stream via cuda_stream().
 * @return true when there is nothing to do or no CUDA error is pending after
 *         the launch; false when the runtime reports an error.
 */
template <typename T_IN, typename T_OUT, class Context>
bool TransformOnGPU(
    Tensor& X,
    Tensor* Y,
    Tensor& mean,
    Tensor& std,
    Context* context) {
  const int N = X.dim32(0), C = X.dim32(3), H = X.dim32(1), W = X.dim32(2);
  auto* input_data = X.template data<T_IN>();
  // Requested before the empty check so the output is materialized exactly as
  // it was before this guard existed.
  auto* output_data = Y->template mutable_data<T_OUT>();

  // A grid of 0 blocks is an invalid launch configuration, and an empty
  // extent leaves nothing to transform, so skip the launch entirely.
  if (N <= 0 || C <= 0 || H <= 0 || W <= 0) {
    return true;
  }

  transform_kernel<T_IN, T_OUT>
      <<<N, dim3(16, 16), 0, context->cuda_stream()>>>(
          N,
          C,
          H,
          W,
          mean.template data<float>(),
          std.template data<float>(),
          input_data,
          output_data);

  // Kernel launches do not return a status. Peek (rather than get) so that a
  // pending error is reported here without being cleared for later checks.
  return cudaPeekAtLastError() == cudaSuccess;
}
// Explicit instantiations of TransformOnGPU for the CUDA context:
// uint8_t input with float output, and uint8_t input with at::Half output.
template bool TransformOnGPU<uint8_t, float, CUDAContext>(
    Tensor&,
    Tensor*,
    Tensor&,
    Tensor&,
    CUDAContext*);

template bool TransformOnGPU<uint8_t, at::Half, CUDAContext>(
    Tensor&,
    Tensor*,
    Tensor&,
    Tensor&,
    CUDAContext*);
} // namespace caffe2