Enforce same input tensor storage in VariableType functions (#16305)

Summary:
In VariableType.cpp, when a function modifies its input tensors, it should only change the input tensors' storage data in-place, and should never change the input tensors' storage pointers. This PR adds checks for this, and also fixes functions that fail this test.

This is part of the Variable/Tensor merge work (https://github.com/pytorch/pytorch/issues/13638).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16305

Differential Revision: D13897855

Pulled By: yf225

fbshipit-source-id: 0c4fc7eb530d30db88037b1f0981f6f8454d3b79
This commit is contained in:
Will Feng 2019-02-11 12:48:17 -08:00 committed by Facebook Github Bot
parent 4b454c3bdd
commit e2a5b203fc
12 changed files with 130 additions and 3 deletions

View file

@ -72,6 +72,9 @@ TensorImpl* SparseTensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
" changing dimensionality via maybe_zero_dim");
return this;
}
// Sparse tensors never expose a Storage (storage() below unconditionally
// errors), so always report false.
bool SparseTensorImpl::has_storage() const {
return false;
}
// Sparse tensors have no underlying Storage; callers should gate on
// has_storage() (always false here) rather than calling this.
const Storage& SparseTensorImpl::storage() const {
AT_ERROR("sparse tensors do not have storage");
}

View file

@ -51,6 +51,7 @@ public:
int64_t dim() const override;
TensorImpl* maybe_zero_dim(bool condition_when_zero_dim) override;
bool has_storage() const override;
const Storage& storage() const override;
int64_t storage_offset() const override;

View file

@ -184,6 +184,9 @@ class CAFFE2_API Tensor {
ScalarType scalar_type() const {
return typeMetaToScalarType(impl_->dtype());
}
// Whether this tensor is backed by a Storage.  defined() is checked first so
// an undefined tensor reports false instead of dispatching into the undefined
// impl (presumably UndefinedTensorImpl, whose has_storage() errors out).
bool has_storage() const {
return defined() && impl_->has_storage();
}
// Return the Storage backing this tensor.  Check has_storage() first: impls
// without storage (e.g. sparse, undefined) raise an error here.
const Storage& storage() const {
return impl_->storage();
}

View file

@ -175,7 +175,10 @@ std::tuple<Tensor&,Tensor&> gesv_out(Tensor& solution, Tensor& lu, const Tensor&
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.gesv() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
std::tie(solution, lu) = at::_gesv_helper(self, A);
Tensor solution_tmp, lu_tmp;
std::tie(solution_tmp, lu_tmp) = at::_gesv_helper(self, A);
solution.resize_as_(solution_tmp).copy_(solution_tmp);
lu.resize_as_(lu_tmp).copy_(lu_tmp);
return std::tuple<Tensor&, Tensor&>(solution, lu);
}

View file

@ -184,6 +184,9 @@ class CAFFE2_API Tensor {
ScalarType scalar_type() const {
return typeMetaToScalarType(impl_->dtype());
}
// Whether this tensor is backed by a Storage.  defined() is checked first so
// an undefined tensor reports false instead of dispatching into the undefined
// impl (presumably UndefinedTensorImpl, whose has_storage() errors out).
bool has_storage() const {
return defined() && impl_->has_storage();
}
// Return the Storage backing this tensor.  Check has_storage() first: impls
// without storage (e.g. sparse, undefined) raise an error here.
const Storage& storage() const {
return impl_->storage();
}

View file

@ -110,6 +110,10 @@ TensorImpl* TensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
return this;
}
// Whether storage_ currently holds a storage.  Relies on Storage's implicit
// boolean conversion — presumably truthy when it owns a storage
// implementation; confirm against Storage's operator bool.
bool TensorImpl::has_storage() const {
return storage_;
}
// Default implementation: hand out the member storage_ directly.  Subclasses
// without storage (sparse, undefined) override this to error instead.
const Storage& TensorImpl::storage() const {
return storage_;
}

View file

@ -292,7 +292,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
virtual int64_t dim() const;
/**
* Return the underyling storage of a Tensor. Multiple tensors may share
* True if this tensor has storage. See storage() for details.
*/
virtual bool has_storage() const;
/**
* Return the underlying storage of a Tensor. Multiple tensors may share
* a single storage. A Storage is an impoverished, Tensor-like class
* which supports far less operations than Tensor.
*

View file

@ -24,6 +24,10 @@ int64_t UndefinedTensorImpl::dim() const {
AT_ERROR("dim() called on undefined Tensor");
}
// Even querying for storage is an error on an undefined tensor; callers are
// expected to check Tensor::defined() before reaching this impl.
bool UndefinedTensorImpl::has_storage() const {
AT_ERROR("has_storage() called on undefined Tensor");
}
// Undefined tensors have no Storage to return; always an error.
const Storage& UndefinedTensorImpl::storage() const {
AT_ERROR("storage() called on undefined Tensor");
}

View file

@ -22,6 +22,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
int64_t size(int64_t d) const override;
int64_t stride(int64_t d) const override;
int64_t dim() const override;
bool has_storage() const override;
const Storage& storage() const override;
int64_t storage_offset() const override;
private:

View file

@ -29,7 +29,6 @@ from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
from .gen_autograd import VIEW_FUNCTIONS
from .gen_autograd_functions import uses_single_grad
# These functions are written manually in templates/VariableType.cpp
MANUAL_IMPLEMENTATIONS = {
'resize_', 'resize_as_', 'detach', 'detach_', 's_copy_', '_s_copy_from'
@ -80,6 +79,68 @@ DONT_REQUIRE_DERIVATIVE = {
'_coalesced_',
}
# NOTE [ Invariant: TensorImpl and Storage Pointer Equality ]
#
# When a function modifies its input tensors (via inplace or out-variants),
# it should never change the input tensors' underlying c10::TensorImpl pointers
# or c10::Storage pointers.
#
# The following code templates implement the checks for this invariant:
# Snapshot one input tensor's Storage (if it has one) before the call.
SAVE_TENSOR_STORAGE = CodeTemplate("""\
c10::optional<Storage> ${tensor_name}_storage_saved =
${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
""")
# After the call, assert the tensor's storage still aliases the snapshot.
ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate("""\
if (${tensor_name}_storage_saved.has_value())
AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${tensor_name}.storage()));
""")
# Snapshot each tensor's Storage (if present) in an input TensorList before
# the call, so ENFORCE_SAME_TENSORLIST_STORAGE can assert the call did not
# swap any storage out from under the caller.
#
# The vector is pre-sized to ${tensorlist_name}.size() and filled BY INDEX.
# The previous push_back-based fill was a bug: push_back appends *after* the
# size() default-constructed empty optionals, so the enforcement loop (which
# reads indices [0, size)) only ever saw nullopt and the check was vacuous.
# Index-based filling mirrors SAVE_TENSORLIST_IMPL below and also avoids
# copying each Tensor in the range-for.
SAVE_TENSORLIST_STORAGE = CodeTemplate("""\
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
for (size_t i=0; i<${tensorlist_name}.size(); i++)
  if (${tensorlist_name}[i].has_storage())
    ${tensorlist_name}_storage_saved[i] = ${tensorlist_name}[i].storage();
""")
# After the call, assert each list element's storage aliases its snapshot.
ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate("""\
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
if (${tensorlist_name}_storage_saved[i].has_value())
AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
}
""")
# Snapshot one input tensor's TensorImpl pointer before the call.
SAVE_TENSOR_IMPL = CodeTemplate("""\
c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
""")
# After the call, assert the tensor still points at the same TensorImpl.
ENFORCE_SAME_TENSOR_IMPL = CodeTemplate("""\
if (${tensor_name}_impl_saved) AT_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
""")
# Snapshot each TensorImpl pointer in an input TensorList (filled by index).
SAVE_TENSORLIST_IMPL = CodeTemplate("""\
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
for (size_t i=0; i<${tensorlist_name}.size(); i++)
if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
""")
# After the call, assert each list element's TensorImpl pointer is unchanged.
ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate("""\
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
if (${tensorlist_name}_impl_saved[i])
AT_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
}
""")
# The following list contains functions that we don't enforce the invariant on.
DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
# These functions are expected to change impl or storage of input tensors
'_th_set_', '_cudnn_rnn_flatten_weight',
# TODO: Fix these functions to update input tensor in-place
'tril_', 'triu_',
}
# END CHECKS FOR [ Invariant: TensorImpl and Storage Pointer Equality ]
# Declaration emitted into the VariableType header for each dispatched method.
METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
""")
@ -189,6 +250,12 @@ if (tracer_state) {
}
""")
RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\
#ifndef NDEBUG
${statements}
#endif
""")
FACTORY_FUNCTION_NAMES = None
@ -608,6 +675,29 @@ def emit_body(declaration):
else:
return 'as_variable({})'.format(call), []
# Wrap the generated `call` with debug-only checks asserting that the call did
# not change any input tensor's TensorImpl or Storage pointer (see NOTE
# [ Invariant: TensorImpl and Storage Pointer Equality ]).  `declaration`
# comes from the enclosing emit_body scope.  Returns `call` unmodified for
# functions in the opt-out list or with no Tensor/TensorList args.
def enforce_same_tensorimpl_and_storage(env, call):
# Statements emitted before the call: snapshot impl/storage pointers.
save_ptrs_stmts = []
# Statements emitted after the call: assert the pointers are unchanged.
enforce_same_ptrs_stmts = []
# Some functions legitimately replace impl or storage; skip those.
if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
for arg in env.get('unpacked_args', []):
simple_type = env['unpacked_args_simple_type'][arg]
if simple_type == 'TensorList':
save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
elif simple_type == 'Tensor':
save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
# Save and enforce statements are always generated in matched pairs.
assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (not save_ptrs_stmts and not enforce_same_ptrs_stmts)
if save_ptrs_stmts and enforce_same_ptrs_stmts:
# The checks are compiled only in debug builds (wrapped in #ifndef NDEBUG).
call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
call + \
RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
return call
def emit_call(env):
combined = nested_dict(env, declaration)
extra_wrapping_stmts = []
@ -634,6 +724,7 @@ def emit_body(declaration):
call = call + ';'
for stmt in extra_wrapping_stmts:
call += '\n' + stmt
call = enforce_same_tensorimpl_and_storage(env, call)
return call
def tie_return_values():
@ -725,9 +816,11 @@ def unpack_args(env, declaration):
body = []
unpacked_args = []
unpacked_args_simple_type = {}
for i, arg in enumerate(declaration['arguments']):
if not requires_unpack(arg):
unpacked_args.append(arg['name'])
unpacked_args_simple_type[arg['name']] = arg['simple_type']
continue
dynamic_type = arg['dynamic_type']
@ -749,8 +842,10 @@ def unpack_args(env, declaration):
body.append(UNPACK_OPTIONS.substitute(arg_name=arg['name']))
unpacked_args.append(arg['name'] + '_')
unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']
env['unpacked_args'] = unpacked_args
env['unpacked_args_simple_type'] = unpacked_args_simple_type
return body

View file

@ -92,6 +92,10 @@ void* Variable::Impl::slow_data() const {
return data_.unsafeGetTensorImpl()->slow_data();
}
// A Variable wraps a data tensor; storage presence is simply forwarded to it.
bool Variable::Impl::has_storage() const {
return data_.has_storage();
}
// Forward to the wrapped data tensor's storage (which may itself error for
// storage-less impls such as sparse tensors).
const at::Storage& Variable::Impl::storage() const {
return data_.storage();
}

View file

@ -407,6 +407,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
void set_storage_offset(int64_t storage_offset) override;
int64_t dim() const override;
bool has_storage() const override;
const at::Storage& storage() const override;
void* slow_data() const override;