Revert "[Inductor] Inplacing with Donated Buffer (#140113)"
This reverts commit eecc8e362c.
Reverted https://github.com/pytorch/pytorch/pull/140113 on behalf of https://github.com/BoyuanFeng because it breaks test_donated_buffer_inplace internally, since donated_buffer = False if is_fbcode() else True ([comment](https://github.com/pytorch/pytorch/pull/140113#issuecomment-2501954300))
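The test broke internally because the feature default differs between builds: in fbcode the donated-buffer analysis is off, so the generated backward kernel never shows the inplace reuse the test asserts. A minimal sketch of a build-independent guard; the config module path torch._functorch.config is an assumption, while the default expression itself is quoted from the revert note above:

    # Hypothetical guard: only assert inplace reuse when the flag is on.
    import torch._functorch.config as functorch_config  # assumed flag location

    if functorch_config.donated_buffer:
        pass  # expect "in_out_ptr" in the generated backward kernel
    else:
        pass  # feature disabled (e.g. the fbcode default); skip the assertion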
This commit is contained in:
parent 869d629c0f
commit 65dbd5cc2d
7 changed files with 17 additions and 180 deletions
@@ -5198,31 +5198,6 @@ class CommonTemplate:
         if self.device != "cpu":
             assertGeneratedKernelCountEqual(self, 1)
 
-    def test_matmul_layer_norm(self):
-        batch_size = 32
-        seq_length = 50
-        hidden_size = 256
-
-        inp = torch.randn(
-            batch_size,
-            seq_length,
-            hidden_size,
-            requires_grad=True,
-            device=self.device,
-        )
-        weight = torch.randn(
-            hidden_size, hidden_size, requires_grad=True, device=self.device
-        )
-
-        layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
-
-        def foo(inp, weight):
-            matmul_output = inp @ weight
-            final_output = layer_norm(matmul_output)
-            return final_output
-
-        self.common(foo, (inp, weight), check_lowp=False)
-
     def test_transpose_add(self):
         def fn(a, b):
             return a.t() + b
@@ -12880,43 +12855,6 @@ if HAS_GPU and not TEST_WITH_ASAN:
             self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0)
             self.assertEqual(fn_opt(*inps), fn(*inps))
 
-        def test_donated_buffer_inplace(self):
-            batch_size = 32
-            seq_length = 50
-            hidden_size = 256
-
-            inp = torch.randn(
-                batch_size,
-                seq_length,
-                hidden_size,
-                requires_grad=True,
-                device=self.device,
-            )
-            weight = torch.randn(
-                hidden_size, hidden_size, requires_grad=True, device=self.device
-            )
-
-            layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
-
-            def fn(inp, weight):
-                matmul_output = inp @ weight
-                final_output = layer_norm(matmul_output)
-                return final_output
-
-            fn_opt = torch.compile(fn)
-
-            def wrapper(inp, weight):
-                return fn_opt(inp, weight).sum().backward()
-
-            _, code = run_and_get_code(wrapper, inp, weight)
-
-            if config.cpp_wrapper:
-                # when using cpp_wrapper, backward triton code is in code[2]
-                self.assertTrue("in_out_ptr" in code[2])
-            else:
-                # when not using cpp_wrapper, backward triton code is in code[1]
-                self.assertTrue("in_out_ptr" in code[1])
-
     class RNNTest(TestCase):
         device_type = GPU_TYPE
 
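The two hunks above delete the OSS-only tests, likely from test/inductor/test_torchinductor.py. The second one probes for inplace reuse by grepping the generated kernels for in_out_ptr arguments, which mark kernels that read and write the same buffer. A standalone sketch of the same probe, assuming a CUDA device; run_and_get_code is the real helper the test itself used, from torch._inductor.utils:

    import torch
    from torch._inductor.utils import run_and_get_code

    layer_norm = torch.nn.LayerNorm(256, device="cuda")
    inp = torch.randn(32, 50, 256, requires_grad=True, device="cuda")
    weight = torch.randn(256, 256, requires_grad=True, device="cuda")

    fn_opt = torch.compile(lambda a, b: layer_norm(a @ b))

    def wrapper(a, b):
        return fn_opt(a, b).sum().backward()

    # run_and_get_code returns (result, [source of each generated module]).
    _, code = run_and_get_code(wrapper, inp, weight)
    print(any("in_out_ptr" in src for src in code))  # True when inplacing fires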
@@ -2120,11 +2120,7 @@ class PythonWrapperCodegen(CodeGen):
     def codegen_allocation(self, buffer: ir.Buffer):
         name = buffer.get_name()
 
-        if (
-            name in V.graph.removed_buffers
-            or name in self.allocated
-            or isinstance(buffer, ir.DonatedBuffer)
-        ):
+        if name in V.graph.removed_buffers or name in self.allocated:
             return
         self.allocated.add(name)
         if isinstance(
@@ -2178,12 +2174,7 @@
         name = input_buffer.get_name()
         return not (
             name in V.graph.removed_buffers
-            or (
-                name in V.graph.graph_inputs
-                and not isinstance(
-                    V.graph.graph_inputs_original[name], ir.DonatedBuffer
-                )
-            )
+            or name in V.graph.graph_inputs
             or name in V.graph.constants
             or name in V.graph.torchbind_constants
             or name in V.graph.never_reuse_buffers
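Both wrapper hunks restore the pre-PR invariant that graph inputs are neither allocated nor reused by the wrapper codegen; the reverted code had carved out DonatedBuffer inputs as the one reusable exception. A toy sketch of that carve-out as a plain predicate, with illustrative names rather than the real method:

    # Toy model of the reuse rule: a graph input normally blocks reuse,
    # unless it was donated (nothing outside the backward may observe it).
    def may_reuse(name, removed, graph_inputs, donated, constants):
        if name in removed or name in constants:
            return False
        if name in graph_inputs and name not in donated:
            return False
        return True

    print(may_reuse("buf0", set(), {"primals_1"}, set(), set()))         # True
    print(may_reuse("primals_1", set(), {"primals_1"}, set(), set()))    # False
    print(may_reuse("saved_0", set(), {"saved_0"}, {"saved_0"}, set()))  # True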
@@ -832,20 +832,6 @@ class CUDAGraphNode:
             if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
         ]
 
-        # (depth, offset) of live tensors which are alias of previous graph outputs
-        self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
-            (
-                self._is_alias_of_live_recorded_tensor(t)
-                if isinstance(t, torch.Tensor)
-                else None
-            )
-            for t in inputs
-        ]
-
-        # when replay, preserve the liveness of an input if it AliasesPriorGraphOutput
-        # and also aliases an output of the current CUDAGraphNode
-        self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
-
         self.static_input_idxs: List[int] = list(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )
@@ -1052,11 +1038,11 @@
         self.check_static_inputs_are_stable(new_inputs)
 
         self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
+        new_inputs.clear()
 
         self.run_graph()
 
         outputs = self.reconstruct_outputs()
-        new_inputs.clear()
 
         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_after_invocation()
@@ -1275,12 +1261,6 @@
             path_ref = self._is_alias_of_live_recorded_tensor(o)
             if path_ref is not None:
                 self._mark_prior_graph_output_as_aliased(path_ref)
 
-                for idx, inp_path_ref in enumerate(
-                    self.live_cudagraph_managed_path_refs
-                ):
-                    if path_ref == inp_path_ref:
-                        self.preserved_aliased_inputs[idx] = True
-
                 self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
                 continue
@@ -1687,8 +1667,7 @@
         # this invocation. it is too late to check after we've replayed the graph,
         # because we would have already written over their memory.
         for idx in self.cudagraph_managed_idxs:
-            if not self.preserved_aliased_inputs[idx]:
-                inputs[idx] = None  # type: ignore[call-overload]
+            inputs[idx] = None  # type: ignore[call-overload]
 
         torch._check(
             self._check_liveness(
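The four cudagraph hunks undo the alias-preservation machinery: once a donated input can be written inplace, a graph output may alias a cudagraph-managed input, and that input must survive output reconstruction instead of being nulled after replay. This is also why new_inputs.clear() moves back to right after the input copy. A toy sketch of the bookkeeping, in plain Python with made-up names:

    # Toy model: after replay, drop managed inputs except those whose storage
    # is aliased by an output of the current graph (preserved_aliased_inputs).
    inputs = ["saved_act", "param", "tmp"]
    cudagraph_managed_idxs = [0, 2]
    preserved = [True, False, False]  # e.g. an output aliases inputs[0]

    for idx in cudagraph_managed_idxs:
        if not preserved[idx]:
            inputs[idx] = None

    print(inputs)  # ['saved_act', 'param', None]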
@@ -74,7 +74,6 @@ from .exc import (
 )
 from .ir import (
     Constant,
-    DonatedBuffer,
     FixedLayout,
     get_device_type,
     InputBuffer,
@@ -104,7 +103,6 @@ from .utils import (
     convert_shape_to_inductor,
     gather_origins,
     get_cloned_parameter_buffer_name,
-    get_donated_idxs,
     get_sympy_Expr_dtype,
     is_same_tensor,
     maybe_get_suppress_shape_guards_ctx,
@@ -488,11 +486,6 @@ class GraphLowering(torch.fx.Interpreter):
         # state used by for Kernel.workspace
         self.workspace_id = itertools.count()
 
-        # track the current placeholder index that we are processing
-        self.placeholder_idx = -1
-
-        self.bw_donated_idxs = get_donated_idxs()
-
     def has_feature(
         self,
         device: Union[torch._inductor.ir.IRNode, device, None],
@@ -970,7 +963,6 @@
     def placeholder(
         self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> Union[Expr, TensorBox, None]:
-        self.placeholder_idx += 1
         example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         target = self.qualify_name(target)
         if isinstance(example, SymTypes):
@@ -1001,27 +993,13 @@
             sizes, strides = self.static_sizes_strides(example)
         else:
             sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]
-
-        if (
-            self.is_backward
-            and self.bw_donated_idxs
-            and self.placeholder_idx in self.bw_donated_idxs
-        ):
-            tensor = TensorBox.create(
-                DonatedBuffer(
-                    name=target,
-                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
-                )
+        # TODO(jansel): handle input aliasing
+        tensor = TensorBox.create(
+            InputBuffer(
+                name=target,
+                layout=FixedLayout(example.device, example.dtype, sizes, strides),
             )
-        else:
-            # TODO(jansel): handle input aliasing
-            tensor = TensorBox.create(
-                InputBuffer(
-                    name=target,
-                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
-                )
-            )
-
+        )
         self.graph_inputs[target] = tensor
         self.graph_input_names.append(target)
         self.graph_inputs_original[target] = tensor.data.data
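These hunks remove the positional dispatch in GraphLowering.placeholder: AOTAutograd records which backward inputs are donated (bw_donated_idxs), and the lowering turned exactly those placeholders into DonatedBuffer instead of InputBuffer. A toy sketch of the dispatch, with stand-in classes rather than the real ir types:

    # Toy model of the reverted placeholder dispatch: the i-th backward input
    # becomes a DonatedBuffer when i is among the recorded donated indices.
    class InputBuffer: ...
    class DonatedBuffer(InputBuffer): ...

    def make_input(idx, is_backward, bw_donated_idxs):
        if is_backward and bw_donated_idxs and idx in bw_donated_idxs:
            return DonatedBuffer()
        return InputBuffer()

    print(type(make_input(1, True, [1, 3])).__name__)  # DonatedBuffer
    print(type(make_input(0, True, [1, 3])).__name__)  # InputBuffer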
@@ -3832,16 +3832,6 @@ class InputBuffer(Buffer):
         return 1
 
 
-class DonatedBuffer(InputBuffer):
-    """
-    Represents a donated buffer which is a saved tensor that is not alias to any
-    fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace
-    reuse the input tensor memory during backward since it might be used in another
-    function. However, donated buffer can be inplace reused during backward
-    to save memory.
-    """
-
-
 class ConstantBuffer(InputBuffer):
     override_device: Optional[torch.device] = None
 
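The deleted docstring states the safety condition: a saved tensor qualifies as donated only if it aliases no forward input, forward user output, or backward output. A small runnable illustration of that aliasing check on plain tensors; the storage comparison conveys the idea, it is not Inductor's implementation:

    import torch

    x = torch.randn(4, requires_grad=True)  # fwd input
    saved = x.sin()                         # intermediate saved for backward
    out = saved.relu()                      # fwd user output

    aliases = saved.untyped_storage().data_ptr() in (
        x.untyped_storage().data_ptr(),
        out.untyped_storage().data_ptr(),
    )
    print(aliases)  # False: `saved` has private storage, so it is donatable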
@@ -125,16 +125,10 @@ class SchedulerBuffer:
             hasattr(V.kernel, "args")
             and self.get_name() in V.kernel.inplace_update_buffers
         ):
-            input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
-            input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
-            if input_buffer_name in self.scheduler.name_to_donated_buffer:
-                input_buffer = self.scheduler.name_to_donated_buffer[
-                    input_buffer_name
-                ].node
-            else:
-                input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
             V.graph.wrapper_code.codegen_inplace_reuse(
-                input_buffer,
+                self.scheduler.name_to_buf[
+                    V.kernel.inplace_update_buffers[self.get_name()]
+                ].node,
                 self.node,
             )
         else:
@@ -169,11 +163,6 @@
         return self.node.get_mutation_names()
 
 
-@dataclasses.dataclass
-class SchedulerDonatedBuffer(SchedulerBuffer):
-    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
-
-
 class BaseSchedulerNode:
     group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
     read_writes: dependencies.ReadWrites
@@ -453,12 +442,9 @@
                 continue
 
         for read in self.read_writes.reads:
-            input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
-            if read.name in self.scheduler.name_to_donated_buffer:
-                input_buf = self.scheduler.name_to_donated_buffer[read.name]
-            else:
-                input_buf = self.scheduler.name_to_buf.get(read.name)
-
+            input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
+                read.name
+            )
             if (
                 input_buf
                 and V.graph.wrapper_code.can_reuse(input_buf, self)
@@ -484,8 +470,7 @@
                 ),
             )
             and not (
-                input_buf.defining_op
-                and isinstance(
+                isinstance(
                     input_buf.defining_op.node,
                     (ir.FallbackKernel, ir.MultiOutput),
                 )
@@ -1816,9 +1801,6 @@ class Scheduler:
         for node in self.nodes:
             node.prune_deps()
 
-        self.name_to_donated_buffer: Dict[
-            str, SchedulerDonatedBuffer
-        ] = self.get_donated_buffers()
         self.name_to_node: Dict[str, BaseSchedulerNode] = {
             n.get_name(): n for n in self.nodes
         }
@@ -1902,17 +1884,6 @@
             }
         )
 
-    def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
-        name_to_donated_buf = {}
-        for name in V.graph.graph_inputs_original:
-            if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
-                name_to_donated_buf[name] = SchedulerDonatedBuffer(
-                    self,
-                    V.graph.graph_inputs_original[name],
-                    defining_op=None,
-                )
-        return name_to_donated_buf
-
     @property
     def current_device(self) -> Optional[torch.device]:
         return V.graph.current_device
@@ -2189,9 +2160,6 @@
         for buf in node.get_outputs():
             buf.set_users(name_to_users[buf.get_name()].items)
 
-        for name in self.name_to_donated_buffer:
-            self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
-
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users
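The scheduler hunks delete the parallel table for donated inputs: a donated buffer is a graph input, so unlike ordinary buffers it has no defining op, and lookups had to consult name_to_donated_buffer before name_to_buf. A toy sketch of that two-table lookup with placeholder values:

    # Toy model of the reverted lookup order in the scheduler.
    name_to_buf = {"buf0": "SchedulerBuffer(defining_op=op0)"}
    name_to_donated_buffer = {"saved_0": "SchedulerDonatedBuffer(defining_op=None)"}

    def lookup(name):
        if name in name_to_donated_buffer:
            return name_to_donated_buffer[name]
        return name_to_buf.get(name)

    print(lookup("saved_0"))  # the donated entry wins
    print(lookup("buf0"))
    print(lookup("missing"))  # None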
@@ -2200,10 +2200,3 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
     if cls is None:
         return wrap
     return wrap(cls)
-
-
-def get_donated_idxs() -> Optional[List[int]]:
-    tracing_context = torch._guards.TracingContext.try_get()
-    if tracing_context is not None and tracing_context.fw_metadata:
-        return tracing_context.fw_metadata.bw_donated_idxs
-    return None
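The deleted helper reads the donated indices that AOTAutograd stashes on the ambient tracing context; outside a compilation there is no context and it returns None. The same function made self-contained; TracingContext.try_get is the real torch._guards API, and the demo call simply exercises the None path:

    from typing import List, Optional

    import torch
    import torch._guards

    def get_donated_idxs() -> Optional[List[int]]:
        # The tracing context only exists while torch.compile is tracing.
        tracing_context = torch._guards.TracingContext.try_get()
        if tracing_context is not None and tracing_context.fw_metadata:
            return tracing_context.fw_metadata.bw_donated_idxs
        return None

    print(get_donated_idxs())  # None when called outside of compilation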