From 9d363616011bb1ababffe78a97eec2597145fc2f Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Sat, 8 Apr 2023 07:02:51 -0700 Subject: [PATCH] make TensorImpl::data_ptr_impl() non-const and have mutable in the name (#97744) See D44409928. Differential Revision: [D44450468](https://our.internmc.facebook.com/intern/diff/D44450468/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97744 Approved by: https://github.com/ezyang --- aten/src/ATen/templates/TensorMethods.cpp | 2 +- c10/core/TensorImpl.h | 16 ++++++++++++++-- test/cpp/jit/test_backend_compiler_lib.cpp | 2 +- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp index dd8f3c38417..9cbd55fdd11 100644 --- a/aten/src/ATen/templates/TensorMethods.cpp +++ b/aten/src/ATen/templates/TensorMethods.cpp @@ -14,7 +14,7 @@ namespace at { #name \ " but found ", \ scalar_type()); \ - return this->unsafeGetTensorImpl()->data_ptr_impl(); \ + return this->unsafeGetTensorImpl()->mutable_data_ptr_impl(); \ } AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index fa812c6c0b9..b7895a8c080 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1503,7 +1503,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ", while tensor contains ", data_type_.name(), ". "); - return data_ptr_impl(); + return legacy_mutable_data_ptr_impl(); } /** @@ -1512,7 +1512,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * check has_storage() and storage_initialized(). */ template - inline T* data_ptr_impl() const { + inline T* mutable_data_ptr_impl() { + return legacy_mutable_data_ptr_impl(); + } + + private: + // The real implementation of mutable_data_ptr_impl, but in a + // non-const method. + // + // TODO: move the implementation into mutable_data_ptr_impl() and + // delete this when data() is no longer const. + template + inline T* legacy_mutable_data_ptr_impl() const { TORCH_CHECK( has_storage(), "Cannot access data pointer of Tensor that doesn't have storage"); @@ -1525,6 +1536,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return static_cast(storage_.mutable_data()) + storage_offset_; } + public: /** * Return a void* data pointer to the actual data which this tensor refers to. * diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index baa54b0024e..93a2b3a8df2 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -59,7 +59,7 @@ std::vector> parseMethodHandle( } float* float_data_ptr(const at::Tensor& t) { - return t.unsafeGetTensorImpl()->data_ptr_impl(); + return t.data_ptr(); } } // namespace