mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
make TensorImpl::data_ptr_impl() non-const and have mutable in the name (#97744)
See D44409928. Differential Revision: [D44450468](https://our.internmc.facebook.com/intern/diff/D44450468/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97744 Approved by: https://github.com/ezyang
This commit is contained in:
parent
54b168484d
commit
9d36361601
3 changed files with 16 additions and 4 deletions
|
|
@ -14,7 +14,7 @@ namespace at {
|
|||
#name \
|
||||
" but found ", \
|
||||
scalar_type()); \
|
||||
return this->unsafeGetTensorImpl()->data_ptr_impl<T>(); \
|
||||
return this->unsafeGetTensorImpl()->mutable_data_ptr_impl<T>(); \
|
||||
}
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)
|
||||
|
|
|
|||
|
|
@ -1503,7 +1503,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
", while tensor contains ",
|
||||
data_type_.name(),
|
||||
". ");
|
||||
return data_ptr_impl<T>();
|
||||
return legacy_mutable_data_ptr_impl<T>();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -1512,7 +1512,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
* check has_storage() and storage_initialized().
|
||||
*/
|
||||
template <typename T>
|
||||
inline T* data_ptr_impl() const {
|
||||
inline T* mutable_data_ptr_impl() {
|
||||
return legacy_mutable_data_ptr_impl<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
// The real implementation of mutable_data_ptr_impl, but in a
|
||||
// non-const method.
|
||||
//
|
||||
// TODO: move the implementation into mutable_data_ptr_impl() and
|
||||
// delete this when data<T>() is no longer const.
|
||||
template <typename T>
|
||||
inline T* legacy_mutable_data_ptr_impl() const {
|
||||
TORCH_CHECK(
|
||||
has_storage(),
|
||||
"Cannot access data pointer of Tensor that doesn't have storage");
|
||||
|
|
@ -1525,6 +1536,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
return static_cast<T*>(storage_.mutable_data()) + storage_offset_;
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* Return a void* data pointer to the actual data which this tensor refers to.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ std::vector<std::tuple<std::string, int64_t>> parseMethodHandle(
|
|||
}
|
||||
|
||||
float* float_data_ptr(const at::Tensor& t) {
|
||||
return t.unsafeGetTensorImpl()->data_ptr_impl<float>();
|
||||
return t.data_ptr<float>();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue