diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp
index 2e8d51cbade..df43a796acc 100644
--- a/c10/core/StorageImpl.cpp
+++ b/c10/core/StorageImpl.cpp
@@ -40,14 +40,6 @@ void warnDeprecatedDataPtr() {
       "isinstance(tensor, FakeTensor).")
 }
 
-[[noreturn]] void StorageImpl::throw_data_ptr_access_error() const {
-  if (extra_meta_ && extra_meta_->custom_data_ptr_error_msg_) {
-    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-    TORCH_CHECK(false, *extra_meta_->custom_data_ptr_error_msg_);
-  }
-  TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
-}
-
 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
   // Allowlist verification.
   // Only if the devicetype is in the allowlist,
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index 255a81855c4..abe6218fbc9 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
@@ -16,22 +16,9 @@
 
 namespace c10 {
 
-[[noreturn]] C10_API void throwNullDataPtrError();
+C10_API void throwNullDataPtrError();
 C10_API void warnDeprecatedDataPtr();
 
-// Used in StorageImpl to store extra metadata.
-// Currently used only for storing a custom error message
-// used when throwing an exception when data_ptr is accessed.
-struct C10_API StorageExtraMeta {
-  c10::optional<std::string> custom_data_ptr_error_msg_ = c10::nullopt;
-  StorageExtraMeta() = default;
-  StorageExtraMeta(const StorageExtraMeta& other) {
-    if (other.custom_data_ptr_error_msg_) {
-      custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_;
-    }
-  }
-};
-
 // A storage represents the underlying backing data buffer for a
 // tensor. This concept was inherited from the original Torch7
 // codebase; we'd kind of like to get rid of the concept
@@ -136,17 +123,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
   }
 
   const at::DataPtr& data_ptr() const {
-    if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) {
-      throw_data_ptr_access_error();
-    }
     return data_ptr_;
   }
 
   at::DataPtr& mutable_data_ptr() {
-    if (C10_UNLIKELY(has_mutable_data_ptr_check_)) {
-      if (throw_on_immutable_data_ptr_) {
-        throw_data_ptr_access_error();
-      }
+    if (C10_UNLIKELY(has_data_ptr_check_)) {
       if (throw_on_mutable_data_ptr_) {
         throwNullDataPtrError();
       }
@@ -177,17 +158,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
   }
 
   const void* data() const {
-    if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) {
-      throw_data_ptr_access_error();
-    }
     return data_ptr_.get();
   }
 
   void* mutable_data() {
-    if (C10_UNLIKELY(has_mutable_data_ptr_check_)) {
-      if (throw_on_immutable_data_ptr_) {
-        throw_data_ptr_access_error();
-      }
+    if (C10_UNLIKELY(has_data_ptr_check_)) {
       if (throw_on_mutable_data_ptr_) {
         throwNullDataPtrError();
       }
@@ -273,22 +248,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     return &pyobj_slot_;
   }
 
-  StorageExtraMeta& get_extra_meta() {
-    if (!extra_meta_) {
-      extra_meta_ = std::make_unique<StorageExtraMeta>();
-    }
-    return *extra_meta_;
-  }
-
-  [[noreturn]] void throw_data_ptr_access_error() const;
-
-  void release_data_and_set_meta_custom_data_ptr_error_msg_(
-      c10::optional<std::string> s) {
-    throw_on_immutable_data_ptr_ = true;
-    get_extra_meta().custom_data_ptr_error_msg_ = std::move(s);
-    refresh_has_data_ptr_check();
-  }
-
   void set_throw_on_mutable_data_ptr() {
     throw_on_mutable_data_ptr_ = true;
     refresh_has_data_ptr_check();
@@ -314,8 +273,8 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
 
  private:
   void refresh_has_data_ptr_check() {
-    has_mutable_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
-        warn_deprecated_on_mutable_data_ptr_ || throw_on_immutable_data_ptr_;
+    has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
+        warn_deprecated_on_mutable_data_ptr_;
   }
 
   inline bool is_cow() const {
@@ -339,16 +298,13 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
   // All special checks in data/data_ptr calls are guarded behind this single
   // boolean. This is for performance: .data/.data_ptr calls are commonly in the
   // hot-path.
-  bool has_mutable_data_ptr_check_ = false;
+  bool has_data_ptr_check_ = false;
   // If we should throw when mutable_data_ptr() or mutable_data() is called.
   bool throw_on_mutable_data_ptr_ = false;
-  // If we should throw when data_ptr() or data() is called.
-  bool throw_on_immutable_data_ptr_ = false;
   // If we warn when mutable_data_ptr() or mutable_data() is called.
   bool warn_deprecated_on_mutable_data_ptr_ = false;
   Allocator* allocator_;
   impl::PyObjectSlot pyobj_slot_;
-  std::unique_ptr<StorageExtraMeta> extra_meta_ = nullptr;
 };
 
 // Declare StorageImpl create function pointer types.
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index da1473866c7..417e120d5a0 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -882,18 +882,15 @@ if HAS_CUDA and not TEST_WITH_ASAN:
         def foo2(x):
             return x[2:]
 
+        x = torch.rand([10, 10], device="cuda", requires_grad=True)
         param_c = cdata(m.weight)
         for _ in range(3):
-            x = torch.rand([10, 10], device="cuda", requires_grad=True)
-            torch.compiler.cudagraph_mark_step_begin()
             out1, alias_1, alias_2 = foo(m, x)
             self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
             out2 = foo2(out1)
             out2.sum().backward()
             self.assertEqual(cdata(out1), cdata(out2))
-            m.weight.grad = None
-            m.bias.grad = None
 
         node = self.curr_node()
         first_node = next(node._path_from_root)
@@ -1498,37 +1495,12 @@ if HAS_CUDA and not TEST_WITH_ASAN:
            out = foo(inp)
            out2 = foo(inp)
 
-            with self.assertRaisesRegex(
-                Exception, "overwritten by a subsequent forward run."
-            ):
+            with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
                 out + out
 
             foo(inp)
 
-            with self.assertRaisesRegex(
-                Exception, "overwritten by a subsequent forward run."
-            ):
-                out2 + out2
-
-        def test_error_on_dealloc_use2(self):
-            @torch.compile()
-            def foo(x):
-                return x * x * x
-
-            inp = torch.rand([4], device="cuda")
-            out = foo(inp).detach()
-            out2 = foo(inp).detach()
-
-            with self.assertRaisesRegex(
-                Exception, "overwritten by a subsequent forward run."
-            ):
-                out + out
-
-            foo(inp)
-
-            with self.assertRaisesRegex(
-                Exception, "overwritten by a subsequent forward run."
-            ):
+            with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
                 out2 + out2
 
         @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
@@ -1555,7 +1527,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
             for _ in range(4):
                 foo(inp).sum().backward()
-                inp.grad = None
 
             streams = {
                 seg["stream"] for seg in get_all_cudagraph_segments()
@@ -1643,7 +1614,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             out2.sum().backward()
 
             self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
-            ones.grad = None
 
             del out
             del out2
@@ -2025,7 +1995,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
             for _ in range(3):
                 fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
-                fn_compiled.param.grad = None
 
             # Change static tensor address
             fn_compiled.param.data = torch.rand([2, 2], device="cuda")
@@ -2062,13 +2031,11 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
             for _ in range(3):
                 fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
-                fn_compiled.param.grad = None
 
             for _ in range(5):
                 # Change static tensor address
                 fn_compiled.param.data = torch.rand([2, 2], device="cuda")
                 fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
-                fn_compiled.param.grad = None
 
             FileCheck().check_count(
                 "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index a19a9afa26b..db2e65e0622 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1854,7 +1854,6 @@ def _tensors_data_ptrs_at_indices_equal(tensors: List[Tensor], ptrs: List[Option
 def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
 def _storage_Use_Count(storage_ptr: _int) -> _int: ...
 def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
-def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ...
 def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
 def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 6ed09252b27..2020d05b62b 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -1862,24 +1862,22 @@ class CUDAGraphTreeManager:
         #   mod2(mod1(x)).sum().backward()
 
         self.running_forwards_with_pending_backwards = False
-        self.mode: Optional[CompilationMode] = None
 
     def run(self, new_inputs: List[Tensor], function_id: FunctionID):
         assert self.graph is not None, "Running CUDAGraph after shutdown"
-        self.mode = self.id_to_mode[function_id]
         out = self._run(new_inputs, function_id)
 
         # The forwards are only pending following invocation, not before
-        if self.mode == CompilationMode.FORWARD:
+        mode = self.id_to_mode[function_id]
+        if mode == CompilationMode.FORWARD:
             self.running_forwards_with_pending_backwards = True
-        elif self.mode == CompilationMode.BACKWARD:
+        elif mode == CompilationMode.BACKWARD:
             self.running_forwards_with_pending_backwards = False
 
         return out
 
     def set_to_running_backward(self):
         self.running_forwards_with_pending_backwards = False
-        self.mode = CompilationMode.BACKWARD
 
     def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]:
         return (
@@ -2290,18 +2288,9 @@ class CUDAGraphTreeManager:
     def dealloc_current_path_weakrefs(self):
         # TODO: we could also allow the these weak refs to continue to be allocated,
         # but that adds some complications.
-        run_type = "backward" if self.mode == CompilationMode.BACKWARD else "forward"
-
-        stack_map = {}
         for node in self.current_node._path_from_root:
-            assert (
-                len(node.tensor_weakrefs)
-                == len(node.stack_traces)
-                == len(node.outputs_weakrefs)
-            )
-            for t, stack_trace, stor_ref in zip(
-                node.tensor_weakrefs, node.stack_traces, node.outputs_weakrefs
-            ):
+            assert len(node.tensor_weakrefs) == len(node.stack_traces)
+            for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
                 ten = None if t is None else t()
                 if ten is None:
                     continue
@@ -2311,17 +2300,9 @@ class CUDAGraphTreeManager:
                     if stack_trace
                     else "[Could not find stack trace]"
                 )
-
-                if stor_ref is not None and stor_ref() is not None:
-                    stack_map[stor_ref.data_ptr()] = stack_trace
-
-                ten = None if t is None else t()
-                if ten is None:
-                    continue
-
                 msg = (
-                    "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent "
-                    f"{run_type} run. Stack trace: {stack_trace}. "
+                    "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
+                    f"Stack trace: {stack_trace}. "
                     "To prevent overwriting, clone the tensor outside of torch.compile() "
                     "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
                 )
@@ -2331,17 +2312,7 @@ class CUDAGraphTreeManager:
         for storage_ref in self.current_node.path_live_weakrefs():
             if storage_ref() and storage_ref.data_ptr() not in deleted:
                 deleted.add(storage_ref.data_ptr())
-                stack_trace = stack_map.get(
-                    storage_ref.data_ptr(), "[Could not find stack trace]"
-                )
-                msg = (
-                    "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent "
-                    f"{run_type} run. Stack trace: {stack_trace}. "
-                    "To prevent overwriting, clone the tensor outside of torch.compile() "
-                    "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
-                )
                 torch._C._free_And_Remove_DeleterFn(storage_ref())
-                torch._C._set_storage_data_ptr_access_error_msg(storage_ref(), msg)
 
     def clear_current_path_state_and_set_to_none(self):
         self.current_node.clear_path_state()
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index f8154570d89..765647f3a9d 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1214,13 +1214,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         ->release_storage_and_set_meta_custom_data_ptr_error_msg_(s);
   });
 
-  m.def(
-      "_set_storage_data_ptr_access_error_msg",
-      [](size_t storage_impl_ptr, std::string s) {
-        c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
-        storage_impl->release_data_and_set_meta_custom_data_ptr_error_msg_(s);
-      });
-
  m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) {
    // NOLINTNEXTLINE(performance-no-int-to-ptr)
    c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
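
For context, a minimal repro of the user-visible behavior the updated test_error_on_dealloc_use above asserts — a sketch, assuming a CUDA-enabled PyTorch build with cudagraphs enabled via mode="reduce-overhead" (the test harness enables cudagraph trees differently):

    import torch

    @torch.compile(mode="reduce-overhead")
    def foo(x):
        return x * x * x

    inp = torch.rand([4], device="cuda")
    out = foo(inp)
    foo(inp)  # a subsequent run reuses (overwrites) the cudagraph-managed output buffer

    # Accessing `out` now raises with the unified message
    # "...overwritten by a subsequent run...", with no forward/backward distinction,
    # since the per-run_type wording and the storage-level custom error path are removed.
    out + out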