From ee0143bf6524e5ad7d7fe8292887981ad8dadd4c Mon Sep 17 00:00:00 2001
From: mikey dagitses
Date: Tue, 11 Apr 2023 23:08:33 -0700
Subject: [PATCH] distinguish mutability of TensorImpl::data() (#98719)

There already is a mutable_data() with different semantics, so we
introduce new names: TensorImpl::(mutable_)?data_dtype_initialized().

Differential Revision: [D44824778](https://our.internmc.facebook.com/intern/diff/D44824778/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98719
Approved by: https://github.com/ezyang
---
 aten/src/ATen/cuda/CUDAGeneratorImpl.cpp |  2 +-
 aten/src/ATen/mps/MPSGeneratorImpl.mm    |  2 +-
 c10/core/TensorImpl.h                    | 53 ++++++++++++++++++------
 caffe2/core/tensor.h                     |  2 +-
 4 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
index f0973c6dd75..ecba45c1904 100644
--- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
+++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -189,7 +189,7 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   }
 
   uint64_t input_seed;
-  auto new_rng_state = new_state.data<uint8_t>();
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state, seed_size);
   this->set_current_seed(input_seed);
   int64_t philox_offset = 0;
diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.mm b/aten/src/ATen/mps/MPSGeneratorImpl.mm
index ed7be96c8c7..6b95204adca 100644
--- a/aten/src/ATen/mps/MPSGeneratorImpl.mm
+++ b/aten/src/ATen/mps/MPSGeneratorImpl.mm
@@ -82,7 +82,7 @@ void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
 
   uint64_t input_seed = default_rng_seed_val;
-  auto new_rng_state = new_state.data<uint8_t>();
+  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state + states_size, seed_size);
   this->set_current_seed(input_seed);
   // state.data must be copied after input_seed to not reset the state in set_current_seed()
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 92ff75662c5..b697b553807 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1496,17 +1496,46 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * for you; this class is available from 'Tensor'.
    */
   template <typename T>
-  inline T* data() const {
+  const T* data_dtype_initialized() const {
+    return data_dtype_initialized_impl<const T>(
+        [this] { return static_cast<const T*>(storage_.data()); });
+  }
+
+  /**
+   * Return a mutable typed data pointer to the actual data which this
+   * tensor refers to. This checks that the requested type (from the
+   * template parameter) matches the internal type of the tensor.
+   *
+   * It is invalid to call data() on a dtype-uninitialized tensor, even if
+   * the size is 0.
+   *
+   * WARNING: If a tensor is not contiguous, you MUST use strides when
+   * performing index calculations to determine the location of elements in
+   * the tensor. We recommend using 'TensorAccessor' to handle this computation
+   * for you; this class is available from 'Tensor'.
+   */
+  template <typename T>
+  T* mutable_data_dtype_initialized() {
+    return data_dtype_initialized_impl<T>(
+        [this] { return static_cast<T*>(storage_.mutable_data()); });
+  }
+
+ private:
+  // Shared implementation of data_dtype_initialized() and
+  // mutable_data_dtype_initialized().
+  template <typename T, typename Func>
+  T* data_dtype_initialized_impl(const Func& get_data) const {
     TORCH_CHECK(
-        data_type_.Match<T>(),
+        data_type_.Match<std::remove_const_t<T>>(),
         "Tensor type mismatch, caller expects elements to be ",
-        caffe2::TypeMeta::TypeName<T>(),
+        caffe2::TypeMeta::TypeName<std::remove_const_t<T>>(),
         ", while tensor contains ",
         data_type_.name(),
         ". ");
-    return legacy_mutable_data_ptr_impl<T>();
+    return data_ptr_impl_impl<T>(get_data);
   }
 
+ public:
   /**
    * More efficient helper for Tensor::data_ptr(). Like data(), but
    * does not do a type check. Unlike the untemplated data(), does
@@ -1514,17 +1543,15 @@
    */
   template <typename T>
   inline T* mutable_data_ptr_impl() {
-    return legacy_mutable_data_ptr_impl<T>();
+    return data_ptr_impl_impl<T>(
+        [this] { return static_cast<T*>(storage_.mutable_data()); });
   }
 
  private:
-  // The real implementation of mutable_data_ptr_impl, but in a
-  // non-const method.
-  //
-  // TODO: move the implementation into mutable_data_ptr_impl() and
-  // delete this when data() is no longer const.
-  template <typename T>
-  inline T* legacy_mutable_data_ptr_impl() const {
+  // Shared implementation of mutable_data_ptr_impl() and the future
+  // mutable_data_ptr_impl().
+  template <typename T, typename Func>
+  T* data_ptr_impl_impl(const Func& get_data) const {
     TORCH_CHECK(
         has_storage(),
         "Cannot access data pointer of Tensor that doesn't have storage");
@@ -1534,7 +1561,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         "Caffe2 uses a lazy allocation, so you will need to call "
         "mutable_data() or raw_mutable_data() to actually allocate memory.");
     // Caller does the type check.
-    return static_cast<T*>(storage_.mutable_data()) + storage_offset_;
+    return get_data() + storage_offset_;
   }
 
  public:
diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h
index 1462178af0f..1726027d5ed 100644
--- a/caffe2/core/tensor.h
+++ b/caffe2/core/tensor.h
@@ -356,7 +356,7 @@ class TORCH_API Tensor final {
 
   template <typename T>
   inline T* data() const {
-    return impl_.get()->data<T>();
+    return impl_.get()->mutable_data_dtype_initialized<T>();
   }
 
   inline void* raw_mutable_data(const TypeMeta meta) const {