mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Invalidate StorageImpl instances when tensor is overwritten with cudagraphs (#125264)"
This reverts commit 8390843eba.
Reverted https://github.com/pytorch/pytorch/pull/125264 on behalf of https://github.com/izaitsevfb because it breaks internal tests ([comment](https://github.com/pytorch/pytorch/pull/125264#issuecomment-2240516202))
This commit is contained in:
parent
35bf05561c
commit
7c299b46ca
6 changed files with 16 additions and 138 deletions
|
|
@ -40,14 +40,6 @@ void warnDeprecatedDataPtr() {
|
|||
"isinstance(tensor, FakeTensor).")
|
||||
}
|
||||
|
||||
[[noreturn]] void StorageImpl::throw_data_ptr_access_error() const {
|
||||
if (extra_meta_ && extra_meta_->custom_data_ptr_error_msg_) {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
TORCH_CHECK(false, *extra_meta_->custom_data_ptr_error_msg_);
|
||||
}
|
||||
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
|
||||
}
|
||||
|
||||
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
|
||||
// Allowlist verification.
|
||||
// Only if the devicetype is in the allowlist,
|
||||
|
|
|
|||
|
|
@ -16,22 +16,9 @@
|
|||
|
||||
namespace c10 {
|
||||
|
||||
[[noreturn]] C10_API void throwNullDataPtrError();
|
||||
C10_API void throwNullDataPtrError();
|
||||
C10_API void warnDeprecatedDataPtr();
|
||||
|
||||
// Used in StorageImpl to store extra metadata.
|
||||
// Currently used only for storing a custom error message
|
||||
// used when throwing an exception when data_ptr is accessed.
|
||||
struct C10_API StorageExtraMeta {
|
||||
c10::optional<std::string> custom_data_ptr_error_msg_ = c10::nullopt;
|
||||
StorageExtraMeta() = default;
|
||||
StorageExtraMeta(const StorageExtraMeta& other) {
|
||||
if (other.custom_data_ptr_error_msg_) {
|
||||
custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// A storage represents the underlying backing data buffer for a
|
||||
// tensor. This concept was inherited from the original Torch7
|
||||
// codebase; we'd kind of like to get rid of the concept
|
||||
|
|
@ -136,17 +123,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
const at::DataPtr& data_ptr() const {
|
||||
if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) {
|
||||
throw_data_ptr_access_error();
|
||||
}
|
||||
return data_ptr_;
|
||||
}
|
||||
|
||||
at::DataPtr& mutable_data_ptr() {
|
||||
if (C10_UNLIKELY(has_mutable_data_ptr_check_)) {
|
||||
if (throw_on_immutable_data_ptr_) {
|
||||
throw_data_ptr_access_error();
|
||||
}
|
||||
if (C10_UNLIKELY(has_data_ptr_check_)) {
|
||||
if (throw_on_mutable_data_ptr_) {
|
||||
throwNullDataPtrError();
|
||||
}
|
||||
|
|
@ -177,17 +158,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
const void* data() const {
|
||||
if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) {
|
||||
throw_data_ptr_access_error();
|
||||
}
|
||||
return data_ptr_.get();
|
||||
}
|
||||
|
||||
void* mutable_data() {
|
||||
if (C10_UNLIKELY(has_mutable_data_ptr_check_)) {
|
||||
if (throw_on_immutable_data_ptr_) {
|
||||
throw_data_ptr_access_error();
|
||||
}
|
||||
if (C10_UNLIKELY(has_data_ptr_check_)) {
|
||||
if (throw_on_mutable_data_ptr_) {
|
||||
throwNullDataPtrError();
|
||||
}
|
||||
|
|
@ -273,22 +248,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||
return &pyobj_slot_;
|
||||
}
|
||||
|
||||
StorageExtraMeta& get_extra_meta() {
|
||||
if (!extra_meta_) {
|
||||
extra_meta_ = std::make_unique<StorageExtraMeta>();
|
||||
}
|
||||
return *extra_meta_;
|
||||
}
|
||||
|
||||
[[noreturn]] void throw_data_ptr_access_error() const;
|
||||
|
||||
void release_data_and_set_meta_custom_data_ptr_error_msg_(
|
||||
c10::optional<std::string> s) {
|
||||
throw_on_immutable_data_ptr_ = true;
|
||||
get_extra_meta().custom_data_ptr_error_msg_ = std::move(s);
|
||||
refresh_has_data_ptr_check();
|
||||
}
|
||||
|
||||
void set_throw_on_mutable_data_ptr() {
|
||||
throw_on_mutable_data_ptr_ = true;
|
||||
refresh_has_data_ptr_check();
|
||||
|
|
@ -314,8 +273,8 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||
|
||||
private:
|
||||
void refresh_has_data_ptr_check() {
|
||||
has_mutable_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
|
||||
warn_deprecated_on_mutable_data_ptr_ || throw_on_immutable_data_ptr_;
|
||||
has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
|
||||
warn_deprecated_on_mutable_data_ptr_;
|
||||
}
|
||||
|
||||
inline bool is_cow() const {
|
||||
|
|
@ -339,16 +298,13 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
|||
// All special checks in data/data_ptr calls are guarded behind this single
|
||||
// boolean. This is for performance: .data/.data_ptr calls are commonly in the
|
||||
// hot-path.
|
||||
bool has_mutable_data_ptr_check_ = false;
|
||||
bool has_data_ptr_check_ = false;
|
||||
// If we should throw when mutable_data_ptr() or mutable_data() is called.
|
||||
bool throw_on_mutable_data_ptr_ = false;
|
||||
// If we should throw when data_ptr() or data() is called.
|
||||
bool throw_on_immutable_data_ptr_ = false;
|
||||
// If we warn when mutable_data_ptr() or mutable_data() is called.
|
||||
bool warn_deprecated_on_mutable_data_ptr_ = false;
|
||||
Allocator* allocator_;
|
||||
impl::PyObjectSlot pyobj_slot_;
|
||||
std::unique_ptr<StorageExtraMeta> extra_meta_ = nullptr;
|
||||
};
|
||||
|
||||
// Declare StorageImpl create function pointer types.
|
||||
|
|
|
|||
|
|
@ -882,18 +882,15 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
def foo2(x):
|
||||
return x[2:]
|
||||
|
||||
x = torch.rand([10, 10], device="cuda", requires_grad=True)
|
||||
param_c = cdata(m.weight)
|
||||
for _ in range(3):
|
||||
x = torch.rand([10, 10], device="cuda", requires_grad=True)
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
out1, alias_1, alias_2 = foo(m, x)
|
||||
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
|
||||
|
||||
out2 = foo2(out1)
|
||||
out2.sum().backward()
|
||||
self.assertEqual(cdata(out1), cdata(out2))
|
||||
m.weight.grad = None
|
||||
m.bias.grad = None
|
||||
|
||||
node = self.curr_node()
|
||||
first_node = next(node._path_from_root)
|
||||
|
|
@ -1498,37 +1495,12 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
out = foo(inp)
|
||||
out2 = foo(inp)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "overwritten by a subsequent forward run."
|
||||
):
|
||||
with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
|
||||
out + out
|
||||
|
||||
foo(inp)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "overwritten by a subsequent forward run."
|
||||
):
|
||||
out2 + out2
|
||||
|
||||
def test_error_on_dealloc_use2(self):
|
||||
@torch.compile()
|
||||
def foo(x):
|
||||
return x * x * x
|
||||
|
||||
inp = torch.rand([4], device="cuda")
|
||||
out = foo(inp).detach()
|
||||
out2 = foo(inp).detach()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "overwritten by a subsequent forward run."
|
||||
):
|
||||
out + out
|
||||
|
||||
foo(inp)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "overwritten by a subsequent forward run."
|
||||
):
|
||||
with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
|
||||
out2 + out2
|
||||
|
||||
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
|
||||
|
|
@ -1555,7 +1527,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
|
||||
for _ in range(4):
|
||||
foo(inp).sum().backward()
|
||||
inp.grad = None
|
||||
|
||||
streams = {
|
||||
seg["stream"] for seg in get_all_cudagraph_segments()
|
||||
|
|
@ -1643,7 +1614,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
out2.sum().backward()
|
||||
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
|
||||
|
||||
ones.grad = None
|
||||
del out
|
||||
del out2
|
||||
|
||||
|
|
@ -2025,7 +1995,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
|
||||
for _ in range(3):
|
||||
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
||||
fn_compiled.param.grad = None
|
||||
|
||||
# Change static tensor address
|
||||
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
|
||||
|
|
@ -2062,13 +2031,11 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
|||
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
|
||||
for _ in range(3):
|
||||
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
||||
fn_compiled.param.grad = None
|
||||
|
||||
for _ in range(5):
|
||||
# Change static tensor address
|
||||
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
|
||||
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
||||
fn_compiled.param.grad = None
|
||||
|
||||
FileCheck().check_count(
|
||||
"skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
|
||||
|
|
|
|||
|
|
@ -1854,7 +1854,6 @@ def _tensors_data_ptrs_at_indices_equal(tensors: List[Tensor], ptrs: List[Option
|
|||
def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
|
||||
def _storage_Use_Count(storage_ptr: _int) -> _int: ...
|
||||
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
|
||||
def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ...
|
||||
def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
|
||||
def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...
|
||||
|
||||
|
|
|
|||
|
|
@ -1862,24 +1862,22 @@ class CUDAGraphTreeManager:
|
|||
# mod2(mod1(x)).sum().backward()
|
||||
|
||||
self.running_forwards_with_pending_backwards = False
|
||||
self.mode: Optional[CompilationMode] = None
|
||||
|
||||
def run(self, new_inputs: List[Tensor], function_id: FunctionID):
|
||||
assert self.graph is not None, "Running CUDAGraph after shutdown"
|
||||
self.mode = self.id_to_mode[function_id]
|
||||
out = self._run(new_inputs, function_id)
|
||||
|
||||
# The forwards are only pending following invocation, not before
|
||||
if self.mode == CompilationMode.FORWARD:
|
||||
mode = self.id_to_mode[function_id]
|
||||
if mode == CompilationMode.FORWARD:
|
||||
self.running_forwards_with_pending_backwards = True
|
||||
elif self.mode == CompilationMode.BACKWARD:
|
||||
elif mode == CompilationMode.BACKWARD:
|
||||
self.running_forwards_with_pending_backwards = False
|
||||
|
||||
return out
|
||||
|
||||
def set_to_running_backward(self):
|
||||
self.running_forwards_with_pending_backwards = False
|
||||
self.mode = CompilationMode.BACKWARD
|
||||
|
||||
def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]:
|
||||
return (
|
||||
|
|
@ -2290,18 +2288,9 @@ class CUDAGraphTreeManager:
|
|||
def dealloc_current_path_weakrefs(self):
|
||||
# TODO: we could also allow these weak refs to continue to be allocated,
|
||||
# but that adds some complications.
|
||||
run_type = "backward" if self.mode == CompilationMode.BACKWARD else "forward"
|
||||
|
||||
stack_map = {}
|
||||
for node in self.current_node._path_from_root:
|
||||
assert (
|
||||
len(node.tensor_weakrefs)
|
||||
== len(node.stack_traces)
|
||||
== len(node.outputs_weakrefs)
|
||||
)
|
||||
for t, stack_trace, stor_ref in zip(
|
||||
node.tensor_weakrefs, node.stack_traces, node.outputs_weakrefs
|
||||
):
|
||||
assert len(node.tensor_weakrefs) == len(node.stack_traces)
|
||||
for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces):
|
||||
ten = None if t is None else t()
|
||||
if ten is None:
|
||||
continue
|
||||
|
|
@ -2311,17 +2300,9 @@ class CUDAGraphTreeManager:
|
|||
if stack_trace
|
||||
else "[Could not find stack trace]"
|
||||
)
|
||||
|
||||
if stor_ref is not None and stor_ref() is not None:
|
||||
stack_map[stor_ref.data_ptr()] = stack_trace
|
||||
|
||||
ten = None if t is None else t()
|
||||
if ten is None:
|
||||
continue
|
||||
|
||||
msg = (
|
||||
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent "
|
||||
f"{run_type} run. Stack trace: {stack_trace}. "
|
||||
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
|
||||
f"Stack trace: {stack_trace}. "
|
||||
"To prevent overwriting, clone the tensor outside of torch.compile() "
|
||||
"or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
|
||||
)
|
||||
|
|
@ -2331,17 +2312,7 @@ class CUDAGraphTreeManager:
|
|||
for storage_ref in self.current_node.path_live_weakrefs():
|
||||
if storage_ref() and storage_ref.data_ptr() not in deleted:
|
||||
deleted.add(storage_ref.data_ptr())
|
||||
stack_trace = stack_map.get(
|
||||
storage_ref.data_ptr(), "[Could not find stack trace]"
|
||||
)
|
||||
msg = (
|
||||
"Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent "
|
||||
f"{run_type} run. Stack trace: {stack_trace}. "
|
||||
"To prevent overwriting, clone the tensor outside of torch.compile() "
|
||||
"or call torch.compiler.cudagraph_mark_step_begin() before each model invocation."
|
||||
)
|
||||
torch._C._free_And_Remove_DeleterFn(storage_ref())
|
||||
torch._C._set_storage_data_ptr_access_error_msg(storage_ref(), msg)
|
||||
|
||||
def clear_current_path_state_and_set_to_none(self):
|
||||
self.current_node.clear_path_state()
|
||||
|
|
|
|||
|
|
@ -1214,13 +1214,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
|
|||
->release_storage_and_set_meta_custom_data_ptr_error_msg_(s);
|
||||
});
|
||||
|
||||
m.def(
|
||||
"_set_storage_data_ptr_access_error_msg",
|
||||
[](size_t storage_impl_ptr, std::string s) {
|
||||
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
|
||||
storage_impl->release_data_and_set_meta_custom_data_ptr_error_msg_(s);
|
||||
});
|
||||
|
||||
m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
|
||||
|
|
|
|||
Loading…
Reference in a new issue