Remove Context dependency from Tensor class (#14269)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14269

Removes reference to Context proper and instead adds a bool argument for async copy (the same as `copy_`)

For CopyFrom - I haven't tweaked all callsites yet. Instead I rely on a terrible hack that pointer to context is implicitly converted to bool when passed, haha :) It's not a good code and I propose to fix it in a follow up diff (maybe using clangr tooling).

Reviewed By: ezyang

Differential Revision: D13117981

fbshipit-source-id: 7cb1dc2ba6a4c50ac26614f45ab8318ea96e3138
This commit is contained in:
Dmytro Dzhulgakov 2018-11-28 15:43:22 -08:00 committed by Facebook Github Bot
parent 0cfbbceac3
commit da9e49e586
22 changed files with 65 additions and 63 deletions

View file

@ -917,9 +917,9 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
* a tensor on CPU and then CopyFrom a CUDA tensor, that will to a
* CUDA-to-CPU transfer).
*
* If the function is invoked without `context` the copy would be synchronous
* 'async' parameter triggers async copy for CUDA tensors
*/
void CopyFrom(const TensorImpl& src, at::BaseContext* context = nullptr) {
void CopyFrom(const TensorImpl& src, bool async = false) {
AT_ASSERT(!is_variable());
AT_ASSERTM(
src.is_contiguous(),
@ -978,7 +978,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
src.device(),
new_data,
device(),
context != nullptr);
async);
}
}
}
@ -991,8 +991,10 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
* elements, in which case this tensors' capacity is grown at a factor of
* growthPct. This ensures that Extend runs on an amortized O(1) time
* complexity.
*
* This op is auto-asynchronous if the underlying device (CUDA) supports it.
*/
void Extend(int64_t num, float growthPct, at::BaseContext* context) {
void Extend(int64_t num, float growthPct) {
AT_ASSERT(sizes_.size() >= 1u);
AT_ASSERTM(num >= 0, "`num` must be non-negative for Extend");
AT_ASSERTM(
@ -1022,8 +1024,6 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
auto oldDims = sizes_;
Resize(newCapacity);
auto* newData = raw_mutable_data(data_type_);
AT_ASSERTM(
context != nullptr, "Context must be provided to Extend the tensor");
if (data_type_.copy()) {
AT_ASSERTM(
device_type() == ::at::DeviceType::CPU,

View file

@ -159,7 +159,7 @@ void ReinitializeAndCopyFrom(
Tensor* t,
at::TensorOptions options,
const Tensor& src,
BaseContext* context) {
bool async) {
auto device_type = options.device().type();
CAFFE_ENFORCE(t != nullptr, "Target tensor ptr is null.");
if (!*t || device_type != t->GetDeviceType()) {
@ -172,7 +172,7 @@ void ReinitializeAndCopyFrom(
t->dtype(),
" to: ",
src.dtype());
t->CopyFrom(src, context);
t->CopyFrom(src, async);
}
namespace {

View file

@ -97,23 +97,22 @@ class CAFFE2_API Tensor final {
return impl_.get()->GetDevice();
}
void CopyFrom(const Tensor& src, BaseContext* context = nullptr) const {
impl_.get()->CopyFrom(*src.impl_.get(), context);
void CopyFrom(const Tensor& src, bool async = false) const {
impl_.get()->CopyFrom(*src.impl_.get(), async);
}
/**
* @brief Extend the outer-most dimension of this tensor
* to dimension of `num`.
*/
void ExtendTo(int64_t num, float growthPct, BaseContext* context) const {
void ExtendTo(int64_t num, float growthPct) const {
CAFFE_ENFORCE_GE_WITH_CALLER(impl_->dim(), 1);
CAFFE_ENFORCE_GE_WITH_CALLER(growthPct, 0);
CAFFE_ENFORCE(context != nullptr, "Context must be provided.");
Extend(num - impl_->size(0), growthPct, context);
Extend(num - impl_->size(0), growthPct);
}
void Extend(int64_t num, float growthPct, BaseContext* context) const {
impl_.get()->Extend(num, growthPct, context);
void Extend(int64_t num, float growthPct) const {
impl_.get()->Extend(num, growthPct);
}
/**
@ -451,7 +450,7 @@ CAFFE2_API void ReinitializeAndCopyFrom(
Tensor* t,
at::TensorOptions options,
const Tensor& src,
BaseContext* context = nullptr);
bool async = false);
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(12, Tensor)

View file

@ -52,7 +52,7 @@ class TTPadOp final : public Operator<Context> {
int64_t padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
auto dim0_diff = padded_dim0 - X_dim0;
// set growthPct to the upper bound percentage: (100 * scale_ / X_dim0)
X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0, &context_);
X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0);
auto* X_pad_data = X_pad->template mutable_data<T>();
int64_t X_size = X_dim0 * X_dim1;

View file

@ -2302,8 +2302,8 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
int csz = im_i_boxes.rows();
int cur_start_idx = out_rois->dim(0);
out_rois->Extend(csz, 50, &context_);
out_rois_probs->Extend(csz, 50, &context_);
out_rois->Extend(csz, 50);
out_rois_probs->Extend(csz, 50);
// write rois
Eigen::Map<ERArrXXf> cur_rois(

View file

@ -167,9 +167,9 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
// Write results
int cur_start_idx = out_scores->size(0);
out_scores->Extend(total_keep_count, 50, &context_);
out_boxes->Extend(total_keep_count, 50, &context_);
out_classes->Extend(total_keep_count, 50, &context_);
out_scores->Extend(total_keep_count, 50);
out_boxes->Extend(total_keep_count, 50);
out_classes->Extend(total_keep_count, 50);
int cur_out_idx = 0;
for (int j = 1; j < num_classes; j++) {
@ -202,7 +202,7 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
}
if (out_keeps) {
out_keeps->Extend(total_keep_count, 50, &context_);
out_keeps->Extend(total_keep_count, 50);
Eigen::Map<EArrXi> out_keeps_arr(
out_keeps->template mutable_data<int>() + cur_start_idx,

View file

@ -776,7 +776,7 @@ class AppendOp final : public Operator<Context> {
CAFFE_ENFORCE(a.sizes()[i] == b.sizes()[i]);
}
auto oldSize = c->numel();
c->Extend(b.sizes()[0], kDatasetGrowthPct, &context_);
c->Extend(b.sizes()[0], kDatasetGrowthPct);
auto* dst = (char*)c->raw_mutable_data() + oldSize * b.dtype().itemsize();
context_.CopyItemsSameDevice(b.dtype(), b.numel(), b.raw_data(), dst);
return true;
@ -826,7 +826,7 @@ class AtomicAppendOp final : public Operator<Context> {
continue;
}
auto oldSize = c->numel();
c->Extend(b.sizes()[0], kDatasetGrowthPct, &context_);
c->Extend(b.sizes()[0], kDatasetGrowthPct);
auto* dst = (char*)c->raw_mutable_data() + oldSize * b.dtype().itemsize();
context_.CopyItemsSameDevice(b.dtype(), b.numel(), b.raw_data(), dst);
}

View file

@ -26,7 +26,7 @@ class ExpandDimsOp : public Operator<Context> {
bool RunOnDevice() override {
auto& input = Input(0);
auto* output = Output(0);
output->CopyFrom(input, &context_);
output->CopyFrom(input, true /*async*/);
if (dims_.empty()) {
return true;
}
@ -70,7 +70,7 @@ class SqueezeOp : public Operator<Context> {
bool RunOnDevice() override {
auto& input = Input(0);
auto* output = Output(0);
output->CopyFrom(input, &context_);
output->CopyFrom(input, true /*async*/);
CAFFE_ENFORCE_GT(
input.dim(),

View file

@ -284,8 +284,8 @@ bool GenerateProposalsOp<CPUContext>::RunOnDevice() {
for (int i = 0; i < num_images; i++) {
roi_counts += im_boxes[i].rows();
}
out_rois->Extend(roi_counts, 50, &context_);
out_rois_probs->Extend(roi_counts, 50, &context_);
out_rois->Extend(roi_counts, 50);
out_rois_probs->Extend(roi_counts, 50);
float* out_rois_ptr = out_rois->template mutable_data<float>();
float* out_rois_probs_ptr = out_rois_probs->template mutable_data<float>();
for (int i = 0; i < num_images; i++) {

View file

@ -71,7 +71,7 @@ class LastNWindowCollectorOp : public Operator<Context> {
if (num_entries == 0) {
if (!output_initialized) {
// Get both shape and meta
output->CopyFrom(input, &context_);
output->CopyFrom(input, true /*async*/);
}
return true;
}
@ -83,7 +83,7 @@ class LastNWindowCollectorOp : public Operator<Context> {
// output_num is >= output_batch_size
if (output_num > output_batch_size) {
output->ExtendTo(output_num, 50, &context_);
output->ExtendTo(output_num, 50);
}
auto* output_data =

View file

@ -23,7 +23,7 @@ class MeanOp final : public Operator<Context> {
auto* output = Output(0);
output->ResizeLike(input0);
output->CopyFrom(input0, &context_);
output->CopyFrom(input0, true /*async*/);
if (InputSize() == 1) {
return true;
@ -102,7 +102,7 @@ class MeanGradientOp : public Operator<Context> {
for (int i = 1; i < num_inputs; i++) {
auto* cur_dX = Output(i);
cur_dX->ResizeLike(dY);
cur_dX->CopyFrom(*dX0, &context_);
cur_dX->CopyFrom(*dX0, true /*async*/);
}
return true;

View file

@ -171,7 +171,7 @@ class ONNXWhileOp final : public Operator<Context> {
scan_outputs_sizes[i],
"Size of scan output changed across iterations");
dims.insert(dims.begin(), itr);
scan_output_target->Extend(1, 100, &context_);
scan_output_target->Extend(1, 100);
int64_t timestep_size = 1;
for (const int64_t t : scan_outputs_sizes[i]) {

View file

@ -103,9 +103,9 @@ class ReservoirSamplingOp final : public Operator<Context> {
auto output_num =
std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
// output_num is >= output_batch_size
output->ExtendTo(output_num, 50, &context_);
output->ExtendTo(output_num, 50);
if (pos_to_object) {
pos_to_object->ExtendTo(output_num, 50, &context_);
pos_to_object->ExtendTo(output_num, 50);
}
auto* output_data =

View file

@ -58,7 +58,7 @@ bool RMACRegionsOp<CPUContext>::RunOnDevice() {
(l + Hd - 1 > 0) ? ((H - region_size) / (1.0 * (l + Hd - 1))) : 0;
int cur_rows = output->dim32(0);
output->Extend((l + Wd) * (l + Hd), 50, &context_);
output->Extend((l + Wd) * (l + Hd), 50);
auto* outputData = output->template mutable_data<float>() + cur_rows * 5;
for (int i = 0; i < l + Wd; ++i) {
@ -87,7 +87,7 @@ bool RMACRegionsOp<CPUContext>::RunOnDevice() {
// Replicate regions for all items in batch
int num_rois = output->dim32(0);
output->Extend((batch_size - 1) * num_rois, 50, &context_);
output->Extend((batch_size - 1) * num_rois, 50);
auto* outputData = output->template mutable_data<float>();
for (int b = 1; b < batch_size; ++b) {
// Copy all rois

View file

@ -120,9 +120,9 @@ class RemovePaddingOp final : public Operator<Context> {
bool RunOnDevice() override {
if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
Output(0)->CopyFrom(Input(0), &context_);
Output(0)->CopyFrom(Input(0), true /*async*/);
if (OutputSize() == 2) {
Output(1)->CopyFrom(Input(1), &context_);
Output(1)->CopyFrom(Input(1), true /*async*/);
}
return true;
}
@ -160,9 +160,9 @@ class AddPaddingOp final : public Operator<Context> {
bool RunOnDevice() override {
if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
Output(0)->CopyFrom(Input(0), &context_);
Output(0)->CopyFrom(Input(0), true /*async*/);
if (OutputSize() == 2) {
Output(1)->CopyFrom(Input(1), &context_);
Output(1)->CopyFrom(Input(1), true /*async*/);
}
return true;
}

View file

@ -123,9 +123,9 @@ bool SliceImplGpu(
}
if (dim == -1) {
if (!backward) {
output->CopyFrom(data, context);
output->CopyFrom(data, true /*async*/);
} else {
gdata->CopyFrom(*go, context);
gdata->CopyFrom(*go, true /*async*/);
}
return true;
}

View file

@ -85,9 +85,9 @@ bool SliceImpl(
}
if (dim == -1) {
if (!backward) {
output->CopyFrom(data, context);
output->CopyFrom(data, true /*async*/);
} else {
gdata->CopyFrom(*go, context);
gdata->CopyFrom(*go, true /*async*/);
}
return true;
}

View file

@ -14,7 +14,7 @@ class StopGradientOp : public Operator<Context> {
const auto& in = Input(0);
auto* out = Output(0);
if (out != &in) {
out->CopyFrom(in, &context_);
out->CopyFrom(in, true /*async*/);
}
return true;
}

View file

@ -130,7 +130,7 @@ bool NanCheckOp<CUDAContext>::RunOnDevice() {
// This op should act as an identity matrix if we don't find any NaNs/infs.
// Copy over the data if we are not doing this in-place.
if (&X != Y) {
Y->CopyFrom(X, &context_);
Y->CopyFrom(X, true /*async*/);
}
return true;
}

View file

@ -196,7 +196,7 @@ class EnsureDenseOp final : public Operator<Context> {
// allow the output to be copied from the input
if (&input != output) {
output->ResizeLike(input);
output->CopyFrom(input, &context_);
output->CopyFrom(input, true /*async*/);
}
return true;
}
@ -257,7 +257,7 @@ class SumOp : public Operator<Context> {
auto& input0 = Input(0);
auto* output = Output(0);
if (InputSize() == 1) {
output->CopyFrom(input0, &context_);
output->CopyFrom(input0, true /*async*/);
return true;
}
output->ResizeLike(input0);

View file

@ -160,7 +160,7 @@ class SafeDequeueBlobsOp final : public Operator<Context> {
size,
" total columns");
out->Extend(in.sizes()[0], kTensorGrowthPct, &context_);
out->Extend(in.sizes()[0], kTensorGrowthPct);
auto* dst =
(char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize();
context_.template CopyItems<Context, Context>(

View file

@ -808,14 +808,17 @@ bool VideoInputOp<Context>::Prefetch() {
// prefetch function as well.
if (!std::is_same<Context, CPUContext>::value) {
if (get_rgb_) {
prefetched_clip_rgb_on_device_.CopyFrom(prefetched_clip_rgb_, &context_);
prefetched_clip_rgb_on_device_.CopyFrom(
prefetched_clip_rgb_, true /*async*/);
}
if (get_optical_flow_) {
prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
prefetched_clip_of_on_device_.CopyFrom(
prefetched_clip_of_, true /*async*/);
}
prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
prefetched_label_on_device_.CopyFrom(prefetched_label_, true /*async*/);
if (get_video_id_) {
prefetched_video_id_on_device_.CopyFrom(prefetched_video_id_, &context_);
prefetched_video_id_on_device_.CopyFrom(
prefetched_video_id_, true /*async*/);
}
}
return true;
@ -828,34 +831,34 @@ bool VideoInputOp<Context>::CopyPrefetched() {
auto* clip_rgb_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
clip_rgb_output->CopyFrom(prefetched_clip_rgb_, true /*async*/);
} else {
clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, true /*async*/);
}
}
if (get_optical_flow_) {
auto* clip_of_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
clip_of_output->CopyFrom(prefetched_clip_of_, true /*async*/);
} else {
clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
clip_of_output->CopyFrom(prefetched_clip_of_on_device_, true /*async*/);
}
}
auto* label_output =
OperatorBase::Output<Tensor>(index++, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
label_output->CopyFrom(prefetched_label_, &context_);
label_output->CopyFrom(prefetched_label_, true /*async*/);
} else {
label_output->CopyFrom(prefetched_label_on_device_, &context_);
label_output->CopyFrom(prefetched_label_on_device_, true /*async*/);
}
if (get_video_id_) {
auto* video_id_output =
OperatorBase::Output<Tensor>(index, Context::GetDeviceType());
if (std::is_same<Context, CPUContext>::value) {
video_id_output->CopyFrom(prefetched_video_id_, &context_);
video_id_output->CopyFrom(prefetched_video_id_, true /*async*/);
} else {
video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
video_id_output->CopyFrom(prefetched_video_id_on_device_, true /*async*/);
}
}
return true;