mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Enforce same input tensor storage in VariableType functions (#16305)
Summary: In VariableType.cpp, when a function modifies its input tensors, it should only change the input tensors' storage data in-place, and should never change the input tensors' storage pointers. This PR adds checks for this, and also fixes functions that fail this test. This is part of the Variable/Tensor merge work (https://github.com/pytorch/pytorch/issues/13638). Pull Request resolved: https://github.com/pytorch/pytorch/pull/16305 Differential Revision: D13897855 Pulled By: yf225 fbshipit-source-id: 0c4fc7eb530d30db88037b1f0981f6f8454d3b79
This commit is contained in:
parent
4b454c3bdd
commit
e2a5b203fc
12 changed files with 130 additions and 3 deletions
|
|
@ -72,6 +72,9 @@ TensorImpl* SparseTensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
|
|||
" changing dimensionality via maybe_zero_dim");
|
||||
return this;
|
||||
}
|
||||
bool SparseTensorImpl::has_storage() const {
|
||||
return false;
|
||||
}
|
||||
const Storage& SparseTensorImpl::storage() const {
|
||||
AT_ERROR("sparse tensors do not have storage");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ public:
|
|||
|
||||
int64_t dim() const override;
|
||||
TensorImpl* maybe_zero_dim(bool condition_when_zero_dim) override;
|
||||
bool has_storage() const override;
|
||||
const Storage& storage() const override;
|
||||
int64_t storage_offset() const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,9 @@ class CAFFE2_API Tensor {
|
|||
ScalarType scalar_type() const {
|
||||
return typeMetaToScalarType(impl_->dtype());
|
||||
}
|
||||
bool has_storage() const {
|
||||
return defined() && impl_->has_storage();
|
||||
}
|
||||
const Storage& storage() const {
|
||||
return impl_->storage();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -175,7 +175,10 @@ std::tuple<Tensor&,Tensor&> gesv_out(Tensor& solution, Tensor& lu, const Tensor&
|
|||
AT_CHECK(self.dim() == 2 && A.dim() == 2,
|
||||
"torch.gesv() with the `out` keyword does not support batching. "
|
||||
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
|
||||
std::tie(solution, lu) = at::_gesv_helper(self, A);
|
||||
Tensor solution_tmp, lu_tmp;
|
||||
std::tie(solution_tmp, lu_tmp) = at::_gesv_helper(self, A);
|
||||
solution.resize_as_(solution_tmp).copy_(solution_tmp);
|
||||
lu.resize_as_(lu_tmp).copy_(lu_tmp);
|
||||
return std::tuple<Tensor&, Tensor&>(solution, lu);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,9 @@ class CAFFE2_API Tensor {
|
|||
ScalarType scalar_type() const {
|
||||
return typeMetaToScalarType(impl_->dtype());
|
||||
}
|
||||
bool has_storage() const {
|
||||
return defined() && impl_->has_storage();
|
||||
}
|
||||
const Storage& storage() const {
|
||||
return impl_->storage();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -110,6 +110,10 @@ TensorImpl* TensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
|
|||
return this;
|
||||
}
|
||||
|
||||
bool TensorImpl::has_storage() const {
|
||||
return storage_;
|
||||
}
|
||||
|
||||
const Storage& TensorImpl::storage() const {
|
||||
return storage_;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -292,7 +292,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
virtual int64_t dim() const;
|
||||
|
||||
/**
|
||||
* Return the underlying storage of a Tensor. Multiple tensors may share
|
||||
* True if this tensor has storage. See storage() for details.
|
||||
*/
|
||||
virtual bool has_storage() const;
|
||||
|
||||
/**
|
||||
* Return the underlying storage of a Tensor. Multiple tensors may share
|
||||
* a single storage. A Storage is an impoverished, Tensor-like class
|
||||
* which supports far less operations than Tensor.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ int64_t UndefinedTensorImpl::dim() const {
|
|||
AT_ERROR("dim() called on undefined Tensor");
|
||||
}
|
||||
|
||||
bool UndefinedTensorImpl::has_storage() const {
|
||||
AT_ERROR("has_storage() called on undefined Tensor");
|
||||
}
|
||||
|
||||
const Storage& UndefinedTensorImpl::storage() const {
|
||||
AT_ERROR("storage() called on undefined Tensor");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
|
|||
int64_t size(int64_t d) const override;
|
||||
int64_t stride(int64_t d) const override;
|
||||
int64_t dim() const override;
|
||||
bool has_storage() const override;
|
||||
const Storage& storage() const override;
|
||||
int64_t storage_offset() const override;
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ from .utils import CodeTemplate, nested_dict, write, uninplace_api_name
|
|||
from .gen_autograd import VIEW_FUNCTIONS
|
||||
from .gen_autograd_functions import uses_single_grad
|
||||
|
||||
|
||||
# These functions are written manually in templates/VariableType.cpp
|
||||
MANUAL_IMPLEMENTATIONS = {
|
||||
'resize_', 'resize_as_', 'detach', 'detach_', 's_copy_', '_s_copy_from'
|
||||
|
|
@ -80,6 +79,68 @@ DONT_REQUIRE_DERIVATIVE = {
|
|||
'_coalesced_',
|
||||
}
|
||||
|
||||
# NOTE [ Invariant: TensorImpl and Storage Pointer Equality ]
|
||||
#
|
||||
# When a function modifies its input tensors (via inplace or out-variants),
|
||||
# it should never change the input tensors' underlying c10::TensorImpl pointers
|
||||
# or c10::Storage pointers.
|
||||
#
|
||||
# The following code templates implement the checks for this invariant:
|
||||
SAVE_TENSOR_STORAGE = CodeTemplate("""\
|
||||
c10::optional<Storage> ${tensor_name}_storage_saved =
|
||||
${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
|
||||
""")
|
||||
|
||||
ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate("""\
|
||||
if (${tensor_name}_storage_saved.has_value())
|
||||
AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${tensor_name}.storage()));
|
||||
""")
|
||||
|
||||
SAVE_TENSORLIST_STORAGE = CodeTemplate("""\
|
||||
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
|
||||
for (Tensor tensor : ${tensorlist_name})
|
||||
${tensorlist_name}_storage_saved.push_back(
|
||||
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
|
||||
""")
|
||||
|
||||
ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate("""\
|
||||
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
|
||||
if (${tensorlist_name}_storage_saved[i].has_value())
|
||||
AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
|
||||
}
|
||||
""")
|
||||
|
||||
SAVE_TENSOR_IMPL = CodeTemplate("""\
|
||||
c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
|
||||
if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
|
||||
""")
|
||||
|
||||
ENFORCE_SAME_TENSOR_IMPL = CodeTemplate("""\
|
||||
if (${tensor_name}_impl_saved) AT_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
|
||||
""")
|
||||
|
||||
SAVE_TENSORLIST_IMPL = CodeTemplate("""\
|
||||
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
|
||||
for (size_t i=0; i<${tensorlist_name}.size(); i++)
|
||||
if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
|
||||
""")
|
||||
|
||||
ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate("""\
|
||||
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
|
||||
if (${tensorlist_name}_impl_saved[i])
|
||||
AT_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
|
||||
}
|
||||
""")
|
||||
|
||||
# The following list contains functions that we don't enforce the invariant on.
|
||||
DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
|
||||
# These functions are expected to change impl or storage of input tensors
|
||||
'_th_set_', '_cudnn_rnn_flatten_weight',
|
||||
# TODO: Fix these functions to update input tensor in-place
|
||||
'tril_', 'triu_',
|
||||
}
|
||||
# END CHECKS FOR [ Invariant: TensorImpl and Storage Pointer Equality ]
|
||||
|
||||
METHOD_DECLARATION = CodeTemplate("""\
|
||||
${return_type} ${method_prefix_derived}${api_name}(${type_method_formals}) const override;
|
||||
""")
|
||||
|
|
@ -189,6 +250,12 @@ if (tracer_state) {
|
|||
}
|
||||
""")
|
||||
|
||||
RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\
|
||||
#ifndef NDEBUG
|
||||
${statements}
|
||||
#endif
|
||||
""")
|
||||
|
||||
|
||||
FACTORY_FUNCTION_NAMES = None
|
||||
|
||||
|
|
@ -608,6 +675,29 @@ def emit_body(declaration):
|
|||
else:
|
||||
return 'as_variable({})'.format(call), []
|
||||
|
||||
def enforce_same_tensorimpl_and_storage(env, call):
|
||||
save_ptrs_stmts = []
|
||||
enforce_same_ptrs_stmts = []
|
||||
if declaration['name'] not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
|
||||
for arg in env.get('unpacked_args', []):
|
||||
simple_type = env['unpacked_args_simple_type'][arg]
|
||||
if simple_type == 'TensorList':
|
||||
save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
|
||||
SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
|
||||
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
|
||||
ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
|
||||
elif simple_type == 'Tensor':
|
||||
save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
|
||||
SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
|
||||
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
|
||||
ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
|
||||
assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (not save_ptrs_stmts and not enforce_same_ptrs_stmts)
|
||||
if save_ptrs_stmts and enforce_same_ptrs_stmts:
|
||||
call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
|
||||
call + \
|
||||
RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
|
||||
return call
|
||||
|
||||
def emit_call(env):
|
||||
combined = nested_dict(env, declaration)
|
||||
extra_wrapping_stmts = []
|
||||
|
|
@ -634,6 +724,7 @@ def emit_body(declaration):
|
|||
call = call + ';'
|
||||
for stmt in extra_wrapping_stmts:
|
||||
call += '\n' + stmt
|
||||
call = enforce_same_tensorimpl_and_storage(env, call)
|
||||
return call
|
||||
|
||||
def tie_return_values():
|
||||
|
|
@ -725,9 +816,11 @@ def unpack_args(env, declaration):
|
|||
|
||||
body = []
|
||||
unpacked_args = []
|
||||
unpacked_args_simple_type = {}
|
||||
for i, arg in enumerate(declaration['arguments']):
|
||||
if not requires_unpack(arg):
|
||||
unpacked_args.append(arg['name'])
|
||||
unpacked_args_simple_type[arg['name']] = arg['simple_type']
|
||||
continue
|
||||
|
||||
dynamic_type = arg['dynamic_type']
|
||||
|
|
@ -749,8 +842,10 @@ def unpack_args(env, declaration):
|
|||
body.append(UNPACK_OPTIONS.substitute(arg_name=arg['name']))
|
||||
|
||||
unpacked_args.append(arg['name'] + '_')
|
||||
unpacked_args_simple_type[arg['name'] + '_'] = arg['simple_type']
|
||||
|
||||
env['unpacked_args'] = unpacked_args
|
||||
env['unpacked_args_simple_type'] = unpacked_args_simple_type
|
||||
return body
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -92,6 +92,10 @@ void* Variable::Impl::slow_data() const {
|
|||
return data_.unsafeGetTensorImpl()->slow_data();
|
||||
}
|
||||
|
||||
bool Variable::Impl::has_storage() const {
|
||||
return data_.has_storage();
|
||||
}
|
||||
|
||||
const at::Storage& Variable::Impl::storage() const {
|
||||
return data_.storage();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -407,6 +407,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
|
|||
void set_storage_offset(int64_t storage_offset) override;
|
||||
|
||||
int64_t dim() const override;
|
||||
bool has_storage() const override;
|
||||
const at::Storage& storage() const override;
|
||||
void* slow_data() const override;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue