make TensorImpl::data_ptr_impl() non-const and have mutable in the name (#97744)

See D44409928.

Differential Revision: [D44450468](https://our.internmc.facebook.com/intern/diff/D44450468/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97744
Approved by: https://github.com/ezyang
This commit is contained in:
mikey dagitses 2023-04-08 07:02:51 -07:00 committed by PyTorch MergeBot
parent 54b168484d
commit 9d36361601
3 changed files with 16 additions and 4 deletions

View file

@ -14,7 +14,7 @@ namespace at {
#name \
" but found ", \
scalar_type()); \
return this->unsafeGetTensorImpl()->data_ptr_impl<T>(); \
return this->unsafeGetTensorImpl()->mutable_data_ptr_impl<T>(); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)

View file

@ -1503,7 +1503,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
", while tensor contains ",
data_type_.name(),
". ");
return data_ptr_impl<T>();
return legacy_mutable_data_ptr_impl<T>();
}
/**
@ -1512,7 +1512,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* check has_storage() and storage_initialized().
*/
template <typename T>
inline T* data_ptr_impl() const {
inline T* mutable_data_ptr_impl() {
return legacy_mutable_data_ptr_impl<T>();
}
private:
// The real implementation of mutable_data_ptr_impl, but in a
// non-const method.
//
// TODO: move the implementation into mutable_data_ptr_impl() and
// delete this when data<T>() is no longer const.
template <typename T>
inline T* legacy_mutable_data_ptr_impl() const {
TORCH_CHECK(
has_storage(),
"Cannot access data pointer of Tensor that doesn't have storage");
@ -1525,6 +1536,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return static_cast<T*>(storage_.mutable_data()) + storage_offset_;
}
public:
/**
* Return a void* data pointer to the actual data which this tensor refers to.
*

View file

@ -59,7 +59,7 @@ std::vector<std::tuple<std::string, int64_t>> parseMethodHandle(
}
float* float_data_ptr(const at::Tensor& t) {
return t.unsafeGetTensorImpl()->data_ptr_impl<float>();
return t.data_ptr<float>();
}
} // namespace