Revert "Invalidate StorageImpl instances when tensor is overwritten with cudagraphs (#125264)"

This reverts commit 8390843eba.

Reverted https://github.com/pytorch/pytorch/pull/125264 on behalf of https://github.com/izaitsevfb due to breaks internal tests ([comment](https://github.com/pytorch/pytorch/pull/125264#issuecomment-2240516202))
This commit is contained in:
PyTorch MergeBot 2024-07-19 22:58:51 +00:00
parent 35bf05561c
commit 7c299b46ca
6 changed files with 16 additions and 138 deletions

View file

@ -40,14 +40,6 @@ void warnDeprecatedDataPtr() {
"isinstance(tensor, FakeTensor).")
}
// Raise an error on any attempt to read the data pointer of an invalidated
// StorageImpl. If a custom message was attached via StorageExtraMeta
// (see release_data_and_set_meta_custom_data_ptr_error_msg_), that message is
// used; otherwise a generic one. Marked [[noreturn]]: both branches end in
// TORCH_CHECK(false, ...), which unconditionally throws.
[[noreturn]] void StorageImpl::throw_data_ptr_access_error() const {
// custom_data_ptr_error_msg_ is a c10::optional; guarded by the explicit
// check above, hence the suppressed optional-access lint below.
if (extra_meta_ && extra_meta_->custom_data_ptr_error_msg_) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
TORCH_CHECK(false, *extra_meta_->custom_data_ptr_error_msg_);
}
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View file

@ -16,22 +16,9 @@
namespace c10 {
[[noreturn]] C10_API void throwNullDataPtrError();
C10_API void throwNullDataPtrError();
C10_API void warnDeprecatedDataPtr();
// Used in StorageImpl to store extra metadata.
// Currently used only for storing a custom error message
// used when throwing an exception when data_ptr is accessed.
struct C10_API StorageExtraMeta {
c10::optional<std::string> custom_data_ptr_error_msg_ = c10::nullopt;
StorageExtraMeta() = default;
StorageExtraMeta(const StorageExtraMeta& other) {
if (other.custom_data_ptr_error_msg_) {
custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_;
}
}
};
// A storage represents the underlying backing data buffer for a
// tensor. This concept was inherited from the original Torch7
// codebase; we'd kind of like to get rid of the concept
@ -136,17 +123,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
}
const at::DataPtr& data_ptr() const {
if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) {
throw_data_ptr_access_error();
}
return data_ptr_;
}
at::DataPtr& mutable_data_ptr() {
if (C10_UNLIKELY(has_mutable_data_ptr_check_)) {
if (throw_on_immutable_data_ptr_) {
throw_data_ptr_access_error();
}
if (C10_UNLIKELY(has_data_ptr_check_)) {
if (throw_on_mutable_data_ptr_) {
throwNullDataPtrError();
}
@ -177,17 +158,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
}
const void* data() const {
if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) {
throw_data_ptr_access_error();
}
return data_ptr_.get();
}
void* mutable_data() {
if (C10_UNLIKELY(has_mutable_data_ptr_check_)) {
if (throw_on_immutable_data_ptr_) {
throw_data_ptr_access_error();
}
if (C10_UNLIKELY(has_data_ptr_check_)) {
if (throw_on_mutable_data_ptr_) {
throwNullDataPtrError();
}
@ -273,22 +248,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
StorageExtraMeta& get_extra_meta() {
if (!extra_meta_) {
extra_meta_ = std::make_unique<StorageExtraMeta>();
}
return *extra_meta_;
}
[[noreturn]] void throw_data_ptr_access_error() const;
void release_data_and_set_meta_custom_data_ptr_error_msg_(
c10::optional<std::string> s) {
throw_on_immutable_data_ptr_ = true;
get_extra_meta().custom_data_ptr_error_msg_ = std::move(s);
refresh_has_data_ptr_check();
}
void set_throw_on_mutable_data_ptr() {
throw_on_mutable_data_ptr_ = true;
refresh_has_data_ptr_check();
@ -314,8 +273,8 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
private:
void refresh_has_data_ptr_check() {
has_mutable_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
warn_deprecated_on_mutable_data_ptr_ || throw_on_immutable_data_ptr_;
has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
warn_deprecated_on_mutable_data_ptr_;
}
inline bool is_cow() const {
@ -339,16 +298,13 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
// All special checks in data/data_ptr calls are guarded behind this single
// boolean. This is for performance: .data/.data_ptr calls are commonly in the
// hot-path.
bool has_mutable_data_ptr_check_ = false;
bool has_data_ptr_check_ = false;
// If we should throw when mutable_data_ptr() or mutable_data() is called.
bool throw_on_mutable_data_ptr_ = false;
// If we should throw when data_ptr() or data() is called.
bool throw_on_immutable_data_ptr_ = false;
// If we warn when mutable_data_ptr() or mutable_data() is called.
bool warn_deprecated_on_mutable_data_ptr_ = false;
Allocator* allocator_;
impl::PyObjectSlot pyobj_slot_;
std::unique_ptr<StorageExtraMeta> extra_meta_ = nullptr;
};
// Declare StorageImpl create function pointer types.

View file

@ -882,18 +882,15 @@ if HAS_CUDA and not TEST_WITH_ASAN:
def foo2(x):
return x[2:]
x = torch.rand([10, 10], device="cuda", requires_grad=True)
param_c = cdata(m.weight)
for _ in range(3):
x = torch.rand([10, 10], device="cuda", requires_grad=True)
torch.compiler.cudagraph_mark_step_begin()
out1, alias_1, alias_2 = foo(m, x)
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
out2 = foo2(out1)
out2.sum().backward()
self.assertEqual(cdata(out1), cdata(out2))
m.weight.grad = None
m.bias.grad = None
node = self.curr_node()
first_node = next(node._path_from_root)
@ -1498,37 +1495,12 @@ if HAS_CUDA and not TEST_WITH_ASAN:
out = foo(inp)
out2 = foo(inp)
with self.assertRaisesRegex(
Exception, "overwritten by a subsequent forward run."
):
with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
out + out
foo(inp)
with self.assertRaisesRegex(
Exception, "overwritten by a subsequent forward run."
):
out2 + out2
def test_error_on_dealloc_use2(self):
@torch.compile()
def foo(x):
return x * x * x
inp = torch.rand([4], device="cuda")
out = foo(inp).detach()
out2 = foo(inp).detach()
with self.assertRaisesRegex(
Exception, "overwritten by a subsequent forward run."
):
out + out
foo(inp)
with self.assertRaisesRegex(
Exception, "overwritten by a subsequent forward run."
):
with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
out2 + out2
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
@ -1555,7 +1527,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
for _ in range(4):
foo(inp).sum().backward()
inp.grad = None
streams = {
seg["stream"] for seg in get_all_cudagraph_segments()
@ -1643,7 +1614,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
out2.sum().backward()
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
ones.grad = None
del out
del out2
@ -2025,7 +1995,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
for _ in range(3):
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
fn_compiled.param.grad = None
# Change static tensor address
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
@ -2062,13 +2031,11 @@ if HAS_CUDA and not TEST_WITH_ASAN:
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
for _ in range(3):
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
fn_compiled.param.grad = None
for _ in range(5):
# Change static tensor address
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
fn_compiled.param.grad = None
FileCheck().check_count(
"skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "

View file

@ -1854,7 +1854,6 @@ def _tensors_data_ptrs_at_indices_equal(tensors: List[Tensor], ptrs: List[Option
def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
def _storage_Use_Count(storage_ptr: _int) -> _int: ...
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ...
def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...

View file

@ -1862,24 +1862,22 @@ class CUDAGraphTreeManager:
# mod2(mod1(x)).sum().backward()
self.running_forwards_with_pending_backwards = False
self.mode: Optional[CompilationMode] = None
def run(self, new_inputs: List[Tensor], function_id: FunctionID):
assert self.graph is not None, "Running CUDAGraph after shutdown"
self.mode = self.id_to_mode[function_id]
out = self._run(new_inputs, function_id)
# The forwards are only pending following invocation, not before
if self.mode == CompilationMode.FORWARD:
mode = self.id_to_mode[function_id]
if mode == CompilationMode.FORWARD:
self.running_forwards_with_pending_backwards = True
elif self.mode == CompilationMode.BACKWARD:
elif mode == CompilationMode.BACKWARD:
self.running_forwards_with_pending_backwards = False
return out
def set_to_running_backward(self):
self.running_forwards_with_pending_backwards = False
self.mode = CompilationMode.BACKWARD
def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]:
return (
@ -2290,18 +2288,9 @@ class CUDAGraphTreeManager:
def dealloc_current_path_weakrefs(self):
# TODO: we could also allow the these weak refs to continue to be allocated,
# but that adds some complications.
run_type = "backward" if self.mode == CompilationMode.BACKWARD else "forward"
stack_map = {}
for node in self.current_node._path_from_root:
assert (
len(node.tensor_weakrefs)
== len(node.stack_traces)
== len(node.outputs_weakrefs)
)
for t, stack_trace, stor_ref in zip(
node.tensor_weakrefs, node.stack_traces, node.outputs_weakrefs
):
assert len(node.tensor_weakrefs) == len(node.stack_traces)
for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
ten = None if t is None else t()
if ten is None:
continue
@ -2311,17 +2300,9 @@ class CUDAGraphTreeManager:
if stack_trace
else "[Could not find stack trace]"
)
if stor_ref is not None and stor_ref() is not None:
stack_map[stor_ref.data_ptr()] = stack_trace
ten = None if t is None else t()
if ten is None:
continue
msg = (
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent "
f"{run_type} run. Stack trace: {stack_trace}. "
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
f"Stack trace: {stack_trace}. "
"To prevent overwriting, clone the tensor outside of torch.compile() "
"or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
)
@ -2331,17 +2312,7 @@ class CUDAGraphTreeManager:
for storage_ref in self.current_node.path_live_weakrefs():
if storage_ref() and storage_ref.data_ptr() not in deleted:
deleted.add(storage_ref.data_ptr())
stack_trace = stack_map.get(
storage_ref.data_ptr(), "[Could not find stack trace]"
)
msg = (
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent "
f"{run_type} run. Stack trace: {stack_trace}. "
"To prevent overwriting, clone the tensor outside of torch.compile() "
"or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
)
torch._C._free_And_Remove_DeleterFn(storage_ref())
torch._C._set_storage_data_ptr_access_error_msg(storage_ref(), msg)
def clear_current_path_state_and_set_to_none(self):
self.current_node.clear_path_state()

View file

@ -1214,13 +1214,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
->release_storage_and_set_meta_custom_data_ptr_error_msg_(s);
});
m.def(
"_set_storage_data_ptr_access_error_msg",
[](size_t storage_impl_ptr, std::string s) {
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
storage_impl->release_data_and_set_meta_custom_data_ptr_error_msg_(s);
});
m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;